In [None]:
# Enable CUDA launch blocking for better error traceability.
# With CUDA_LAUNCH_BLOCKING=1 kernel launches run synchronously, so a CUDA
# error is reported at the call that caused it (at some cost in throughput).
# This is the first cell, so the variable is set before the later cells
# import torch.
%env CUDA_LAUNCH_BLOCKING=1

# Install required packages for Google Colab
# Model and embedding stack (imported in the following cells).
%pip install transformers torch sentence-transformers --quiet
# HTTP API layer: flask / flask-cors are imported below.
# NOTE(review): pyngrok is not imported in the visible cells - presumably it
# is used with the NGROK_AUTH_TOKEN secret further down; confirm.
%pip install flask flask-cors pyngrok --quiet
# Supabase client (imported below).
# NOTE(review): python-dotenv is not imported in the visible cells - confirm
# it is still needed.
%pip install supabase python-dotenv --quiet
# NOTE(review): sacremoses is not imported directly in the visible cells -
# presumably a tokenizer dependency of one of the models; confirm.
%pip install sacremoses --quiet


In [None]:
import os
import json
import logging
import traceback
from datetime import datetime
from typing import List, Dict, Any, Optional
import torch
from getpass import getpass

# MedGemma-specific imports - Key difference from standard LLMs
from transformers import AutoProcessor, AutoModelForImageTextToText

# Supabase
from supabase import create_client

# Flask API
from flask import Flask, request, jsonify
from flask_cors import CORS

# Colab secrets
from google.colab import userdata

# Logging: INFO level on the root logger plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Session banner: import confirmation, start timestamp, and the device torch sees.
print(
    "üì¶ All packages imported successfully!",
    f"üïê RAG session started at: {datetime.now():%Y-%m-%d %H:%M:%S}",
    f"üîß Using device: {'cuda' if torch.cuda.is_available() else 'cpu'}",
    sep="\n",
)

# Static description of the MedGemma checkpoint used by the rest of the notebook.
MEDGEMMA_CONFIG = dict(
    name="MedGemma-4B",
    path="google/medgemma-4b-it",
    description="4B parameter multimodal medical model from Google",
    type="multimodal",
    architecture="ImageTextToText",
    specialty="Medical and healthcare domain",
)

print(
    f"\nü§ñ Selected Medical Model: {MEDGEMMA_CONFIG['name']}",
    f"üì¶ Model path: {MEDGEMMA_CONFIG['path']}",
    f"üîß Architecture: {MEDGEMMA_CONFIG['architecture']}",
    f"üè• Specialty: {MEDGEMMA_CONFIG['specialty']}",
    f"üí° Note: MedGemma supports both text and image inputs, but this RAG system uses text-only mode",
    sep="\n",
)

# Report the GPU (if any) and which half-precision format its generation supports.
if not torch.cuda.is_available():
    print("üì± No GPU available - will use CPU with reduced performance")
else:
    device_name = torch.cuda.get_device_name(0)
    compute_capability = torch.cuda.get_device_capability(0)
    print(f"\nüéØ GPU: {device_name}")
    print(f"üî¢ Compute capability: {compute_capability}")

    # Major compute capability 8+ (A100, H100 series) has bfloat16;
    # older parts (T4, V100 series) are float16 only.
    if compute_capability[0] < 8:
        print("‚úÖ GPU supports float16 (optimal for this setup)")
    else:
        print("‚úÖ GPU supports bfloat16 (will use float16 for compatibility)")


In [None]:
# MedGemma Model Loading - Specialized for Multimodal Medical Tasks
#
# This cell tries three loading strategies in order and stops at the first
# one that succeeds:
#   1. MedGemma in float16 with device_map="auto"
#   2. MedGemma in float32 without a device map (CPU)
#   3. microsoft/DialoGPT-medium as an emergency stand-in
# Whichever path wins defines the module-level names `processor`,
# `medgemma_model`, `device` and `MODEL_INFO`. Only path 1 may also define
# `tokenizer`.
print("üß† Loading MedGemma-4B for medical text generation...")
print("üìñ Loading processor (AutoProcessor for multimodal support)...")

# Check GPU memory before loading
if torch.cuda.is_available():
    print(f"üîç GPU Memory before loading: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated, {torch.cuda.memory_reserved()/1024**3:.2f} GB reserved")

try:
    # Key difference: Use AutoProcessor instead of AutoTokenizer for MedGemma
    # This enables proper handling of both text and potential image inputs
    print("üî§ Loading AutoProcessor for MedGemma (supports multimodal inputs)...")
    processor = AutoProcessor.from_pretrained(
        MEDGEMMA_CONFIG["path"], 
        trust_remote_code=True
    )
    print("‚úÖ AutoProcessor loaded successfully")
    
    # Key difference: Use AutoModelForImageTextToText instead of AutoModelForCausalLM
    # MedGemma is built for text + image inputs, even when used in text-only mode
    print("üß† Loading MedGemma multimodal model (this may take a few minutes)...")
    print("üí° Note: Using AutoModelForImageTextToText for MedGemma's multimodal architecture")
    
    medgemma_model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_CONFIG["path"],
        torch_dtype=torch.float16,  # Use float16 instead of bfloat16 for better GPU compatibility
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True
    )
    print("‚úÖ MedGemma model loaded successfully")
    
    # Set device
    # NOTE: `device` is derived from CUDA availability, not read back from the
    # model, so it is the device inputs get moved to rather than a report of
    # where device_map="auto" actually placed the weights.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"üéØ Model device: {device}")
    
    # Check if processor has special tokens (MedGemma might not have traditional pad/eos tokens)
    print("üîç Checking MedGemma processor tokens...")
    if hasattr(processor, 'tokenizer'):
        tokenizer = processor.tokenizer
        print(f"‚úÖ Tokenizer accessible via processor.tokenizer")
        
        # Check for special tokens
        if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
            print(f"üîö EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
        else:
            print("‚ö†Ô∏è No EOS token found - this is normal for some multimodal models")
            
        if hasattr(tokenizer, 'pad_token') and tokenizer.pad_token:
            print(f"üìù PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
        else:
            print("‚ö†Ô∏è No PAD token found - will set if needed during generation")
            # Set pad token if needed (some multimodal models don't have one by default)
            # Reuses EOS as PAD; if there is no EOS either, PAD stays unset.
            if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
                tokenizer.pad_token = tokenizer.eos_token
                print(f"‚úÖ Set PAD token to EOS token: {tokenizer.pad_token}")
    else:
        print("‚ö†Ô∏è Processor doesn't expose tokenizer directly - will handle during generation")
    
    # Store model info for global access
    # NOTE: this variant has no "device" or "fallback" key; the two fallback
    # paths below add those, so readers of MODEL_INFO should use .get() for them.
    MODEL_INFO = {
        "name": MEDGEMMA_CONFIG["name"],
        "path": MEDGEMMA_CONFIG["path"],
        "type": MEDGEMMA_CONFIG["type"],
        "architecture": MEDGEMMA_CONFIG["architecture"],
        "loaded": True
    }
    
    # Check GPU memory after loading
    if torch.cuda.is_available():
        print(f"üîç GPU Memory after loading: {torch.cuda.memory_allocated()/1024**3:.2f} GB allocated, {torch.cuda.memory_reserved()/1024**3:.2f} GB reserved")
    
    print(f"‚úÖ {MEDGEMMA_CONFIG['name']} loaded and ready on {device}")
    print(f"üéØ Architecture: {MEDGEMMA_CONFIG['architecture']}")
    print(f"üí° Ready for text-only medical RAG tasks")
    
except Exception as e:
    # Any failure in the primary path (processor load, model load, or the
    # token inspection above) lands here and triggers the CPU retry.
    print(f"‚ùå Error loading MedGemma: {str(e)}")
    print(f"üîç Full error trace:")
    traceback.print_exc()
    
    print("üîÑ Attempting fallback to CPU loading...")
    try:
        # Try CPU loading as fallback
        # Same checkpoint, but float32 and no device_map.
        print("üì± Loading MedGemma on CPU...")
        processor = AutoProcessor.from_pretrained(
            MEDGEMMA_CONFIG["path"], 
            trust_remote_code=True
        )
        
        medgemma_model = AutoModelForImageTextToText.from_pretrained(
            MEDGEMMA_CONFIG["path"],
            torch_dtype=torch.float32,  # Use float32 for CPU
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        
        device = torch.device("cpu")
        MODEL_INFO = {
            "name": f"{MEDGEMMA_CONFIG['name']} (CPU)",
            "path": MEDGEMMA_CONFIG["path"],
            "type": MEDGEMMA_CONFIG["type"],
            "architecture": MEDGEMMA_CONFIG["architecture"],
            "loaded": True,
            "device": "cpu"
        }
        
        print("‚úÖ MedGemma loaded successfully on CPU")
        
    except Exception as cpu_error:
        print(f"‚ùå CPU fallback also failed: {str(cpu_error)}")
        print("üîÑ Using emergency fallback model...")
        
        # Emergency fallback to a simpler model
        # This is a different model, not MedGemma; MODEL_INFO["fallback"]
        # marks it as such.
        from transformers import AutoTokenizer, AutoModelForCausalLM
        
        fallback_path = "microsoft/DialoGPT-medium"
        print(f"üîÑ Loading emergency fallback: {fallback_path}")
        
        # NOTE: in this path `processor` is a plain tokenizer rather than a
        # multimodal processor, so it has no `.tokenizer` attribute.
        processor = AutoTokenizer.from_pretrained(fallback_path)
        medgemma_model = AutoModelForCausalLM.from_pretrained(
            fallback_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if torch.cuda.is_available() else None
        )
        
        # Reuse EOS as PAD when the tokenizer defines no PAD token.
        if processor.pad_token is None:
            processor.pad_token = processor.eos_token
        
        MODEL_INFO = {
            "name": "DialoGPT-medium (Emergency Fallback)",
            "path": fallback_path,
            "type": "causal",
            "architecture": "CausalLM",
            "loaded": True,
            "fallback": True
        }
        
        print(f"‚úÖ Emergency fallback model loaded")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Pre-load the embedding model to avoid CUDA conflicts with MedGemma
#
# Defines two module-level names used by the retrieval code:
#   embedding_model  - a SentenceTransformer instance, or None if loading failed
#   EMBEDDING_DEVICE - 'cuda' or 'cpu', the device the embedding model lives on
print("üîç Pre-loading embedding model to avoid CUDA memory conflicts with MedGemma...")
try:
    from sentence_transformers import SentenceTransformer

    # Check if there's enough GPU memory for the embedding model alongside MedGemma
    if torch.cuda.is_available():
        # mem_get_info() reports what the driver considers free on the device.
        # memory_allocated() alone only counts live tensors of this process and
        # misses the allocator's reserved cache and any other process, so
        # `total - allocated` over-estimates the headroom.
        free_memory, total_memory = torch.cuda.mem_get_info()
        # Kept as a module-level name because the original cell defined it.
        allocated_memory = torch.cuda.memory_allocated()

        print(f"üîç Available GPU memory: {free_memory/1024**3:.2f} GB")

        # If less than 2GB free, use CPU for embeddings (MedGemma takes priority)
        if free_memory < 2 * 1024**3:  # 2GB threshold
            print("‚ö†Ô∏è Limited GPU memory - using CPU for embedding model (MedGemma takes priority)")
            EMBEDDING_DEVICE = 'cpu'
        else:
            print("‚úÖ Sufficient GPU memory - using GPU for embedding model")
            EMBEDDING_DEVICE = 'cuda'
    else:
        print("üì± No GPU available - using CPU for embedding model")
        EMBEDDING_DEVICE = 'cpu'

    # Single construction site; the device is always passed explicitly so the
    # model ends up exactly where EMBEDDING_DEVICE says it is.
    embedding_model = SentenceTransformer(
        'NeuML/pubmedbert-base-embeddings',
        device=EMBEDDING_DEVICE,
    )

    print(f"‚úÖ Embedding model loaded on {EMBEDDING_DEVICE}")

except Exception as e:
    # Best effort: leave embedding_model unset so the retrieval function can
    # load it on demand on the CPU.
    print(f"‚ö†Ô∏è Error pre-loading embedding model: {str(e)}")
    print("üîÑ Will load embedding model on-demand with CPU fallback")
    embedding_model = None
    EMBEDDING_DEVICE = 'cpu'

# Connect to Supabase with credentials stored as Colab secrets.
# Defines `supabase_url`, `supabase_key`, `ngrok_token` and the `supabase` client.
print(
    "üóÑÔ∏è Setting up Supabase connection using Colab secrets...",
    "üìã Required Colab secrets:",
    "   1. SUPABASE_URL - Your project URL (e.g., https://abc123.supabase.co)",
    "   2. SUPABASE_SERVICE_ROLE_KEY - Your service role key (NOT anon key)",
    "   3. NGROK_AUTH_TOKEN - Your ngrok authentication token",
    "",
    "üîë To set these secrets:",
    "   1. Click the üîë key icon in the left sidebar",
    "   2. Add the three secrets listed above",
    "   3. Re-run this cell",
    "",
    sep="\n",
)

# Read all three secrets; a lookup failure is reported and then re-raised.
try:
    supabase_url = userdata.get('SUPABASE_URL')
    supabase_key = userdata.get('SUPABASE_SERVICE_ROLE_KEY')
    ngrok_token = userdata.get('NGROK_AUTH_TOKEN')

    print("‚úÖ Successfully retrieved secrets from Colab")

except Exception as e:
    print(
        f"‚ùå Error retrieving secrets: {str(e)}",
        "üîß Make sure you've added the required secrets in Colab:",
        "   ‚Ä¢ SUPABASE_URL",
        "   ‚Ä¢ SUPABASE_SERVICE_ROLE_KEY",
        "   ‚Ä¢ NGROK_AUTH_TOKEN",
        sep="\n",
    )
    raise

# Hard requirements: both values present and an https URL.
if not (supabase_url and supabase_key):
    raise ValueError("‚ùå Both Supabase URL and Service Role Key are required!")

if not supabase_url.startswith('https://'):
    raise ValueError("‚ùå Supabase URL should start with 'https://'")

# Soft checks: only warn, since an unusual key may still be valid.
if not supabase_key.startswith('eyJ'):
    print(
        "‚ö†Ô∏è WARNING: Service role keys typically start with 'eyJ'",
        "   You might be using the anon key instead of service_role key",
        sep="\n",
    )

if len(supabase_key) < 100:
    print(
        "‚ö†Ô∏è WARNING: Service role keys are typically very long (200+ characters)",
        "   You might be using the anon key instead of service_role key",
        sep="\n",
    )

# Build the client; a construction failure is explained and then re-raised.
try:
    supabase = create_client(supabase_url, supabase_key)
    print("‚úÖ Supabase client initialized")
except Exception as e:
    print(
        f"‚ùå Failed to initialize Supabase client: {str(e)}",
        "üîß Common issues:",
        "   ‚Ä¢ Wrong API key type (use service_role, not anon)",
        "   ‚Ä¢ Typo in URL or key",
        "   ‚Ä¢ Key might be expired or regenerated",
        sep="\n",
    )
    raise

# Retrieval and generation settings shared by the rest of the notebook.
CONFIG = dict(
    top_k=5,                    # documents requested per query
    similarity_threshold=0.5,   # match threshold passed to the vector search
    max_context_length=2000,    # character budget for the assembled context
    max_response_length=150,
    medgemma_max_tokens=512,    # MedGemma-specific limit
    use_chat_template=True,     # Enable chat template for MedGemma
)

# Echo the effective settings.
print(
    f"\n‚öôÔ∏è MedGemma RAG Configuration:",
    f"   üéØ Retrieve top {CONFIG['top_k']} similar documents",
    f"   üìä Similarity threshold: {CONFIG['similarity_threshold']}",
    f"   üìè Max context length: {CONFIG['max_context_length']} chars",
    f"   ü§ñ MedGemma max tokens: {CONFIG['medgemma_max_tokens']}",
    f"   üí¨ Chat template enabled: {CONFIG['use_chat_template']}",
    sep="\n",
)


In [None]:
# Supabase document retrieval functions - Same as main notebook but with enhanced error handling
def query_supabase_documents(query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
    """Query Supabase for documents similar to *query* using vector search.

    The query is embedded with the module-level ``embedding_model`` (loaded on
    the CPU on first use if it is still ``None``) and sent to the
    ``search_embeddings`` RPC. If anything in that path raises, the function
    falls back to reading up to ``top_k`` rows straight from the
    ``medical_documents`` table; those rows are NOT ranked by similarity and
    carry a placeholder score of 0.8.

    Args:
        query: Free-text question to search for.
        top_k: Maximum number of documents to return. ``None`` (or 0) means
            ``CONFIG['top_k']``.

    Returns:
        A list of dicts with the keys ``content``, ``similarity_score``,
        ``metadata``, ``rank`` and ``doc_id``. The list is empty when nothing
        matched or when both the RPC and the fallback failed. This function
        never returns ``None`` and does not raise for retrieval errors.
    """
    global embedding_model, EMBEDDING_DEVICE

    # Resolve the default before entering the try block so the fallback query
    # below always sees a concrete number, even if embedding loading fails.
    top_k = top_k or CONFIG['top_k']

    try:
        # Use pre-loaded embedding model or load with fallback
        if embedding_model is None:
            print(f"üîç Loading embedding model on-demand for query: {query[:50]}...")
            try:
                from sentence_transformers import SentenceTransformer

                # Always use CPU for on-demand loading to avoid CUDA errors with MedGemma
                embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings', device='cpu')
                EMBEDDING_DEVICE = 'cpu'
                print(f"‚úÖ Embedding model loaded on CPU (fallback)")
            except Exception as load_error:
                print(f"‚ùå Failed to load embedding model: {str(load_error)}")
                raise

        # Generate embedding for the query
        print(f"üß† Generating embedding vector on {EMBEDDING_DEVICE}...")
        query_embedding = embedding_model.encode([query])[0].tolist()

        # Use the correct RPC function from schema.sql: search_embeddings
        print(f"üîç Searching embeddings with threshold {CONFIG['similarity_threshold']}...")
        result = supabase.rpc('search_embeddings', {
            'query_embedding': query_embedding,
            'match_threshold': CONFIG['similarity_threshold'],
            'match_count': top_k
        }).execute()

        if not result.data:
            print("‚ö†Ô∏è No similar documents found in Supabase")
            return []

        documents = [
            {
                'content': doc.get('chunk_content', ''),  # Correct field name from RPC
                'similarity_score': doc.get('similarity', 0.0),
                'metadata': {
                    'title': doc.get('title', 'Medical Document'),
                    'source': doc.get('source', 'unknown'),
                    'topic': doc.get('topic', 'general'),
                    'document_type': doc.get('document_type', 'unknown'),
                    'document_id': doc.get('document_id', '')
                },
                'rank': rank,
                'doc_id': doc.get('document_id', '')
            }
            for rank, doc in enumerate(result.data, start=1)
        ]

        print(f"üìä Found {len(documents)} similar documents from Supabase")
        return documents

    except Exception as e:
        print(f"‚ùå Error querying Supabase: {str(e)}")
        # Fallback: try direct table query if RPC function doesn't exist
        try:
            print("üîÑ Trying fallback query method...")
            result = supabase.table('medical_documents').select('*').limit(top_k).execute()

            # `or []` covers an empty/None payload so the function still
            # returns a list instead of falling off the end.
            rows = (result.data or [])[:top_k]
            documents = [
                {
                    'content': doc.get('content', ''),
                    'similarity_score': 0.8,  # Default similarity (rows are not ranked)
                    'metadata': {
                        'title': doc.get('title', 'Medical Document'),
                        'source': doc.get('source', 'unknown'),
                        'topic': doc.get('topic', 'general'),
                        'document_type': doc.get('document_type', 'unknown'),
                        'document_id': doc.get('id', '')
                    },
                    'rank': rank,
                    'doc_id': doc.get('id', '')
                }
                for rank, doc in enumerate(rows, start=1)
            ]

            if documents:
                print(f"üìä Fallback: Retrieved {len(documents)} documents from Supabase")
            return documents

        except Exception as fallback_error:
            print(f"‚ùå Fallback query also failed: {str(fallback_error)}")
            return []

# Test Supabase connection and RPC functions
#
# Smoke test only: failures are reported with troubleshooting hints and the
# notebook keeps running.
print("üß™ Testing Supabase connection...")
try:
    # Test basic connection
    # count='exact' asks for the table's total row count, which arrives on
    # `test_result.count`; limit(1) keeps the payload to a single row. Taking
    # len() of the returned rows would report the page size, not the number
    # of documents.
    test_result = supabase.table('medical_documents').select('*', count='exact').limit(1).execute()
    doc_count = test_result.count if test_result.count is not None else 0
    print(f"‚úÖ Supabase connected - Found {doc_count} documents in database")
    
    # Test RPC function availability
    print("üß™ Testing RPC functions...")
    try:
        stats_result = supabase.rpc('get_document_stats').execute()
        if stats_result.data:
            print("‚úÖ RPC functions working")
            for stat in stats_result.data[:3]:  # Show first 3 document sources
                print(f"   üìÑ {stat['source']}: {stat['count']} documents")
        else:
            print("‚ö†Ô∏è RPC function exists but returned no data")
    except Exception as rpc_error:
        # An RPC problem is not fatal: retrieval has a table-query fallback.
        print(f"‚ö†Ô∏è RPC function test failed: {str(rpc_error)}")
        print("   Vector search will use fallback method")
        
except Exception as e:
    error_str = str(e)
    print(f"‚ö†Ô∏è Supabase connection test failed: {error_str}")
    
    # Provide specific guidance based on error type
    # Classification is by substring match on the error text.
    if '401' in error_str or 'Invalid API key' in error_str:
        print("üîß AUTHENTICATION ERROR - Invalid API Key:")
        print("   ‚ùå You're using the wrong API key!")
        print("   üìã To fix this:")
        print("   1. Go to your Supabase project dashboard")
        print("   2. Settings ‚Üí API")
        print("   3. Copy the 'service_role' key (NOT anon key)")
        print("   4. The service_role key is much longer and starts with 'eyJ'")
        print("   5. Re-run Cell 4 with the correct key")
        print("")
        print("   üîç Key differences:")
        print("   ‚Ä¢ anon key: Used for client-side apps (WRONG for this notebook)")
        print("   ‚Ä¢ service_role key: Used for server-side/admin access (CORRECT)")
    elif '404' in error_str:
        print("üîß TABLE NOT FOUND:")
        print("   ‚ùå The 'medical_documents' table doesn't exist!")
        print("   üìã To fix this:")
        print("   1. Run the schema.sql in your Supabase SQL editor")
        print("   2. Or run the embed_documents.py script to create tables")
    elif 'timeout' in error_str.lower():
        print("üîß CONNECTION TIMEOUT:")
        print("   ‚ùå Can't reach Supabase servers")
        print("   üìã Check your internet connection and Supabase URL")
    else:
        print("üîß GENERAL CONNECTION ERROR:")
        print("   üìã Common fixes:")
        print("   ‚Ä¢ Double-check your Supabase URL")
        print("   ‚Ä¢ Verify you're using service_role key (not anon)")
        print("   ‚Ä¢ Check if your project is paused/suspended")
        print("   ‚Ä¢ Ensure database tables exist")
    
    print("\n   ‚ö†Ô∏è The system will continue but may have limited document retrieval")


In [None]:
# MedGemma-Specific Response Generation Function
def _move_inputs_to_device(inputs, target_device, float_dtype=None):
    """Return a plain dict with every tensor in *inputs* moved to *target_device*.

    Integer tensors (``input_ids``, ``attention_mask``, ...) keep their dtype,
    because token ids must stay integral. Only floating-point tensors are cast
    to *float_dtype*, and only when one is given.
    """
    moved = {}
    for key, value in inputs.items():
        if not torch.is_tensor(value):
            moved[key] = value
        elif float_dtype is not None and value.is_floating_point():
            moved[key] = value.to(target_device, dtype=float_dtype)
        else:
            moved[key] = value.to(target_device)
    return moved


def generate_medgemma_response(prompt: str, max_new_tokens: int = 150) -> str:
    """Generate a medical answer for *prompt* with the loaded model.

    Key differences from standard LLMs:
    1. Uses processor.apply_chat_template() for proper formatting
    2. Handles multimodal inputs (text-only mode in this case)
    3. Explicit token management and debugging
    4. Enhanced error handling with CUDA debugging

    Relies on the module-level ``processor``, ``medgemma_model``, ``device``,
    ``CONFIG`` and ``MODEL_INFO``.

    Args:
        prompt: The full user prompt (question plus any retrieved context).
        max_new_tokens: Upper bound on generated tokens; additionally capped
            by ``CONFIG['medgemma_max_tokens']``.

    Returns:
        The decoded, cleaned-up model answer. If the answer is empty or
        shorter than 10 characters a generic advice sentence is returned
        instead, and if anything raises an apology string is returned.
        The function never raises.
    """
    try:
        print(f"ü§ñ Generating response using {MODEL_INFO['name']}...")

        # Check GPU memory before generation
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated()/1024**3
            reserved = torch.cuda.memory_reserved()/1024**3
            print(f"üîç GPU Memory before generation: {allocated:.2f} GB allocated, {reserved:.2f} GB reserved")

        # Floating-point inputs (none in text-only mode) follow the model's dtype.
        float_dtype = getattr(medgemma_model, 'dtype', None)
        inputs = None

        # MedGemma-specific: Format input using chat template
        # This is crucial for MedGemma's multimodal architecture
        if CONFIG['use_chat_template'] and hasattr(processor, 'apply_chat_template'):
            print("üí¨ Using chat template formatting for MedGemma...")

            # Format messages for MedGemma's chat template
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful medical assistant. Provide accurate medical information based on the given context."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ]

            print("üî§ Applying chat template...")
            try:
                # Key difference: Use processor.apply_chat_template() for MedGemma
                inputs = processor.apply_chat_template(
                    messages,
                    add_generation_prompt=True,  # Critical for proper generation
                    tokenize=True,
                    return_dict=True,
                    return_tensors="pt"
                )

                # Move to the GPU without touching integer dtypes: casting
                # input_ids to float16 corrupts token ids and breaks the
                # embedding lookup.
                if torch.cuda.is_available():
                    inputs = _move_inputs_to_device(inputs, device, float_dtype)

                print(f"‚úÖ Chat template applied successfully")
                input_len = inputs["input_ids"].shape[-1]

                # Debug: Print input_ids for validation
                print(f"üîç Input shape: {inputs['input_ids'].shape}")
                print(f"üîç Input length: {input_len} tokens")
                print(f"üîç Input IDs sample: {inputs['input_ids'][0][:10].tolist()}...")

            except Exception as template_error:
                print(f"‚ö†Ô∏è Chat template failed: {str(template_error)}")
                print("üîÑ Falling back to direct tokenization...")
                inputs = None
        else:
            # Fallback: Direct tokenization without chat template
            print("üî§ Using direct tokenization (fallback mode)...")

        if inputs is None:
            # `text=` must be a keyword: a multimodal processor takes images
            # as its first positional argument, while a plain tokenizer (the
            # emergency fallback) accepts `text=` as well.
            inputs = processor(text=prompt, return_tensors="pt", truncation=True, max_length=2048)
            if torch.cuda.is_available():
                inputs = _move_inputs_to_device(inputs, device, float_dtype)
            input_len = inputs['input_ids'].shape[1]

        # Generation parameters optimized for MedGemma
        generation_params = {
            "max_new_tokens": min(max_new_tokens, CONFIG['medgemma_max_tokens']),
            "temperature": 0.7,
            "do_sample": True,
            "repetition_penalty": 1.1,
            "top_p": 0.9
        }

        # Explicitly set tokens for MedGemma (critical for proper generation)
        if hasattr(processor, 'tokenizer'):
            tokenizer = processor.tokenizer
            if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
                generation_params["eos_token_id"] = tokenizer.eos_token_id
                print(f"üîö EOS token ID: {tokenizer.eos_token_id}")
            if hasattr(tokenizer, 'pad_token_id') and tokenizer.pad_token_id is not None:
                generation_params["pad_token_id"] = tokenizer.pad_token_id
                print(f"üìù PAD token ID: {tokenizer.pad_token_id}")

        print(f"‚öôÔ∏è Generation params: {generation_params}")

        # Generate with enhanced error handling and debugging
        with torch.no_grad():
            try:
                print("üîÑ Starting generation...")
                outputs = medgemma_model.generate(
                    **inputs,
                    **generation_params
                )
                print("‚úÖ Generation completed successfully")

            except RuntimeError as cuda_error:
                print(f"‚ùå CUDA error during generation: {str(cuda_error)}")
                print(f"üîç Error details: {traceback.format_exc()}")
                error_text = str(cuda_error).lower()

                if "out of memory" in error_text:
                    print("üßπ GPU out of memory - clearing cache and retrying...")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    # Retry with smaller parameters
                    generation_params["max_new_tokens"] = min(generation_params["max_new_tokens"], 50)
                    print(f"üîÑ Retrying with reduced tokens: {generation_params['max_new_tokens']}")

                    outputs = medgemma_model.generate(
                        **inputs,
                        **generation_params
                    )

                elif "cuda" in error_text:
                    print("üîÑ CUDA error - trying CPU fallback...")

                    # Move to CPU for generation
                    # NOTE: this moves the shared model object, so later calls
                    # also run on the CPU.
                    medgemma_model.cpu()
                    inputs = {k: v.cpu() for k, v in inputs.items()}

                    generation_params["max_new_tokens"] = min(generation_params["max_new_tokens"], 100)

                    outputs = medgemma_model.generate(
                        **inputs,
                        **generation_params
                    )

                    print("‚úÖ CPU fallback generation completed")
                else:
                    raise

            except Exception as gen_error:
                print(f"‚ùå General generation error: {str(gen_error)}")
                print(f"üîç Full traceback: {traceback.format_exc()}")
                raise

        # Decode response - MedGemma specific handling
        print("üî§ Decoding response...")

        try:
            # Decode only the newly generated tokens
            generated_tokens = outputs[0][input_len:]

            # Debug: Print generated token info
            print(f"üîç Generated tokens shape: {generated_tokens.shape}")
            print(f"üîç Generated tokens sample: {generated_tokens[:10].tolist()}...")

            if hasattr(processor, 'decode'):
                response = processor.decode(generated_tokens, skip_special_tokens=True)
            elif hasattr(processor, 'tokenizer'):
                response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            else:
                # Emergency fallback
                response = str(generated_tokens.tolist())

        except Exception as decode_error:
            print(f"‚ö†Ô∏è Decoding error: {str(decode_error)}")
            # Fallback decoding
            try:
                # The emergency-fallback `processor` is itself a tokenizer and
                # has no `.tokenizer` attribute.
                fallback_tokenizer = getattr(processor, 'tokenizer', processor)
                response = fallback_tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Extract only the response part (remove input)
                if prompt in response:
                    response = response.split(prompt)[-1]
            except Exception:
                response = "Error: Unable to decode response properly"

        response = response.strip()

        # Clean up MedGemma-specific artifacts
        for artifact in ("</s>", "<s>", "<|im_end|>", "<|im_start|>"):
            response = response.replace(artifact, "")
        response = response.strip()

        # Basic quality check
        if not response or len(response) < 10:
            response = "I understand your question about health. Please consult with a healthcare professional for personalized medical advice."

        print(f"‚úÖ Response generated successfully ({len(response)} characters)")
        print(f"üìÑ Response preview: {response[:100]}...")

        return response

    except Exception as e:
        print(f"‚ùå MedGemma generation error: {str(e)}")
        print(f"üîç Full error trace: {traceback.format_exc()}")
        return f"I apologize, but I encountered an error processing your question with MedGemma. Please try rephrasing your question or consult with a healthcare professional."


# Modular MedGemma RAG System Class
class MedGemmaRAGSystem:
    """RAG system specifically optimized for MedGemma and medical/wellness queries.

    Retrieval is delegated to ``query_supabase_documents`` and generation to
    ``generate_medgemma_response``. The ``config`` mapping must provide the
    keys ``top_k``, ``max_context_length`` and ``max_response_length``.
    """

    def __init__(self, config):
        """Store *config* and print which model the system is bound to."""
        self.config = config
        print(f"üè• Initializing MedGemma RAG System...")
        print(f"üéØ Model: {MODEL_INFO['name']}")
        print(f"üîß Architecture: {MODEL_INFO['architecture']}")

    def _assemble_context(self, query: str, retrieved_docs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the retrieval result for *query* from already-fetched documents.

        Documents are taken in the order given until adding the next one's
        ``content`` would exceed ``config['max_context_length']`` characters;
        that document and all later ones are left out of the context.
        """
        budget = self.config['max_context_length']
        context_parts = []
        total_chars = 0

        for doc in retrieved_docs:
            content = doc['content']
            # Only the raw content counts against the budget, not the
            # "Source: ..." header line added below.
            if total_chars + len(content) > budget:
                break
            context_parts.append(f"Source: {doc['metadata']['source']}\n{content}")
            total_chars += len(content)

        context = "\n\n".join(context_parts)

        return {
            'query': query,
            'context': context,
            'retrieved_documents': retrieved_docs,
            'total_documents_found': len(retrieved_docs),
            # Count what actually went into the context, which can be fewer
            # than what was retrieved when the length budget is hit.
            'documents_used': len(context_parts),
            'context_length': len(context)
        }

    def retrieve_context(self, query: str) -> Dict[str, Any]:
        """Retrieve relevant document chunks from Supabase.

        Returns:
            Dict with ``query``, ``context`` (joined text), ``retrieved_documents``
            (everything returned by the search), ``total_documents_found``,
            ``documents_used`` (documents that fit into the context) and
            ``context_length`` (characters in ``context``).
        """
        retrieved_docs = query_supabase_documents(query, self.config['top_k'])
        return self._assemble_context(query, retrieved_docs)

    def query(self, question: str) -> Dict[str, Any]:
        """Complete RAG query: retrieve context and generate response using MedGemma.

        Returns:
            Dict with ``query``, ``response`` (generated text), ``sources`` (one
            entry per retrieved document, including ones that did not fit into
            the context) and ``metadata``.
        """
        print(f"üîç Processing MedGemma RAG query: {question}")

        context_result = self.retrieve_context(question)

        print(f"üìä Found {context_result['total_documents_found']} similar documents")
        print(f"üìÑ Using {context_result['documents_used']} documents for context")

        print(f"ü§ñ Generating response using {MODEL_INFO['name']}...")
        # Medical prompt optimized for MedGemma
        medical_prompt = f"""You are a helpful and accurate medical assistant. Use the following context to answer the question.

Context:
{context_result['context']}

Question: {question}

Answer (based only on the context, provide helpful medical information):"""

        generated_response = generate_medgemma_response(medical_prompt, self.config['max_response_length'])
        print("‚úÖ MedGemma response generated successfully")

        sources = [
            {
                'title': doc['metadata'].get('title', 'Medical Document'),
                'source': doc['metadata']['source'],
                'topic': doc['metadata']['topic'],
                'similarity': f"{doc['similarity_score']:.3f}",
                'rank': doc['rank'],
                'content_preview': doc['content'][:150] + "..."
            }
            for doc in context_result['retrieved_documents']
        ]

        return {
            'query': question,
            'response': generated_response,
            'sources': sources,
            'metadata': {
                'documentsUsed': context_result['documents_used'],
                'totalFound': context_result['total_documents_found'],
                'contextLength': context_result['context_length'],
                'model': MODEL_INFO['name'],
                'model_path': MODEL_INFO['path'],
                'architecture': MODEL_INFO['architecture'],
                'embeddings': 'Supabase pgvector',
                # ISO timestamp of when the result was built (not a duration)
                'processingTime': datetime.now().isoformat()
            }
        }

# Build the shared RAG instance used by the test cell and the /ask endpoint
medgemma_rag = MedGemmaRAGSystem(config=CONFIG)
print("\n".join([
    "‚úÖ MedGemma RAG system initialized and ready!",
    "üè• Specialized for medical queries with multimodal model architecture",
    "üí° System uses text-only mode for document-based RAG tasks",
]))


In [None]:
# Smoke test: run one sample question through the full RAG pipeline and
# print the answer, its sources and the result metadata.
print("üß™ Testing MedGemma RAG system with sample question...")

test_question = "What are the symptoms of diabetes?"
try:
    print(f"üîç Testing query: {test_question}")
    result = medgemma_rag.query(test_question)

    _rule = "=" * 80
    print(f"\n{_rule}")
    print(f"‚ùì QUESTION: {result['query']}")
    print(_rule)

    print("\nü§ñ MEDGEMMA RESPONSE:")
    print(f"{result['response']}")

    _test_meta = result['metadata']
    print(f"\nüìö SOURCES ({_test_meta['documentsUsed']} documents):")
    if not result['sources']:
        # Nothing retrieved: list the usual causes instead of a source table
        print("\n".join([
            "   ‚ö†Ô∏è No sources found - this could indicate:",
            "   ‚Ä¢ No documents in database yet",
            "   ‚Ä¢ Similarity threshold too high",
            "   ‚Ä¢ RPC function needs adjustment",
        ]))
    else:
        for i, source in enumerate(result['sources'], start=1):
            print(f"   {i}. {source['title']} - {source['source']}")
            print(f"      üìä Similarity: {source['similarity']}")
            print(f"      üìÑ Preview: {source['content_preview']}")
            print()

    print("\nüìä Metadata:")
    print(f"   üîß Model: {_test_meta['model']}")
    print(f"   üèóÔ∏è Architecture: {_test_meta['architecture']}")
    print(f"   üíæ Embeddings: {_test_meta['embeddings']}")
    print(f"   üìÑ Documents Used: {_test_meta['documentsUsed']}")
    print(f"   üéØ Total Found: {_test_meta['totalFound']}")

    print("‚úÖ MedGemma RAG system test completed!")

except Exception as test_error:
    # A failure here is reported but does not stop the notebook
    print(f"‚ö†Ô∏è MedGemma RAG test failed: {str(test_error)}")
    print("\n".join([
        "   This might be normal if:",
        "   ‚Ä¢ Supabase connection needs adjustment",
        "   ‚Ä¢ No documents have been embedded yet",
        "   ‚Ä¢ RPC function is not deployed",
        "   ‚Ä¢ MedGemma model loading issues",
        "   The Flask server will still start and you can test via the API",
    ]))


In [None]:
# Notes-only cell: prints a checklist of the MedGemma-specific adaptations
# and the remaining setup steps. It defines no endpoints itself.
print("\n".join([
    "üìù MedGemma Flask API Setup Notes:",
    "üîß Key adaptations needed for MedGemma:",
    "   ‚Ä¢ Memory management: Prioritize MedGemma GPU memory",
    "   ‚Ä¢ Error handling: Enhanced CUDA debugging",
    "   ‚Ä¢ Generation: Use generate_medgemma_response function",
    "   ‚Ä¢ Chat templates: Proper formatting for multimodal model",
    "   ‚Ä¢ Device management: CPU fallback for embeddings",
    "",
    "üåê To complete setup:",
    "   1. Add Flask endpoints (similar to main notebook)",
    "   2. Implement ngrok tunneling",
    "   3. Start server with enhanced MedGemma support",
    "   4. Test all endpoints with MedGemma-specific handling",
    "",
    "‚úÖ MedGemma RAG system core components ready!",
    "üè• Specialized for medical queries with multimodal architecture",
    "üîß All MedGemma-specific optimizations implemented",
]))


In [None]:
# Flask application object; the route decorators below register the
# MedGemma RAG endpoints on it.
app = Flask(__name__)
# Enable CORS for the whole app using flask-cors defaults (no options passed)
CORS(app)

# In-memory chat history keyed by session id. The /ask endpoint appends
# {"question": ..., "answer": ...} dicts and trims each list to the last 10.
# Nothing is persisted (in production, use Redis or database).
chat_sessions = {}

@app.route('/embed', methods=['POST'])
def generate_embedding():
    """Return an embedding vector for the ``text`` field of a JSON POST body.

    The module-level ``embedding_model`` is used when one is already loaded;
    otherwise a SentenceTransformer is loaded on the CPU on first use and
    cached in that global (``EMBEDDING_DEVICE`` is set to ``'cpu'``).

    Returns:
        200 with ``embedding``, ``dimensions``, ``model`` and ``device``;
        400 when the body is not a JSON object or ``text`` is missing, empty
        or not a string; 500 when model loading or encoding fails.
    """
    global embedding_model, EMBEDDING_DEVICE

    try:
        # silent=True returns None for a missing/malformed JSON body instead
        # of raising, so client mistakes are reported as 400 rather than
        # being swallowed by the blanket handler below and returned as 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400

        text = data.get("text", "")
        if not text:
            return jsonify({"error": "Missing 'text' field."}), 400
        if not isinstance(text, str):
            return jsonify({"error": "'text' must be a string."}), 400

        logger.info(f"üîç Generating embedding for text: {text[:100]}...")

        # Use pre-loaded embedding model or load with CPU fallback.
        # globals().get() also covers the case where no earlier cell ever
        # defined the name, which would otherwise raise NameError here.
        if globals().get('embedding_model') is None:
            logger.info("üì• Loading embedding model on-demand with CPU fallback...")
            try:
                from sentence_transformers import SentenceTransformer

                # Clear GPU cache first to free up memory for MedGemma
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    logger.info("üßπ GPU cache cleared")

                # Always use CPU for Flask requests to avoid CUDA conflicts with MedGemma
                embedding_model = SentenceTransformer('NeuML/pubmedbert-base-embeddings', device='cpu')
                EMBEDDING_DEVICE = 'cpu'
                logger.info("‚úÖ Embedding model loaded on CPU (safer for Flask with MedGemma)")

            except Exception as load_error:
                logger.error(f"‚ùå Failed to load embedding model: {str(load_error)}")
                return jsonify({"error": f"Failed to load embedding model: {str(load_error)}"}), 500

        # A pre-loaded model may exist without EMBEDDING_DEVICE having been set
        device = globals().get('EMBEDDING_DEVICE', 'unknown')

        # Generate embedding
        logger.info(f"üß† Generating embedding on {device}...")
        embedding = embedding_model.encode([text])[0].tolist()

        logger.info(f"‚úÖ Generated embedding with {len(embedding)} dimensions")

        return jsonify({
            "embedding": embedding,
            "dimensions": len(embedding),
            "model": f"PubMedBERT ({device})",
            "device": device
        })

    except Exception as e:
        logger.error(f"‚ùå Error in embed endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({"error": f"Embedding generation failed: {str(e)}"}), 500

@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate a MedGemma answer for ``query`` using caller-supplied context.

    JSON body fields:
        query: Required question text.
        context: Optional background text inserted into the prompt.
        history: Optional list of ``{"question": ..., "answer": ...}`` dicts;
            only the last three are included in the prompt.
        max_tokens: Optional, default 200; passed to
            ``generate_medgemma_response`` as its second argument.

    A ``temperature`` field in the body is ignored: the generation helper is
    called with the prompt and ``max_tokens`` only.

    Returns:
        200 with ``answer`` plus model details and ``context_used`` /
        ``history_used`` flags; 400 for an invalid body; 500 on failure.
    """
    try:
        # silent=True returns None for a missing/malformed JSON body instead
        # of raising, so client mistakes surface as 400 instead of 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400

        query = data.get("query", "")
        # "or" also maps an explicit JSON null to the empty default
        context = data.get("context") or ""
        history = data.get("history") or []  # Chat history support
        max_tokens = data.get("max_tokens", 200)

        if not query:
            return jsonify({"error": "Missing 'query' field."}), 400
        if not isinstance(query, str):
            return jsonify({"error": "'query' must be a string."}), 400
        if not isinstance(context, str):
            return jsonify({"error": "'context' must be a string."}), 400
        if not isinstance(history, list):
            return jsonify({"error": "'history' must be a list."}), 400

        logger.info(f"üî¨ Generating MedGemma response for query: {query[:100]}...")
        logger.info(f"üìö Context length: {len(context)} characters")
        logger.info(f"üí¨ Chat history: {len(history)} messages")

        # Entries that are not dicts cannot supply question/answer and are skipped
        recent_exchanges = [h for h in history if isinstance(h, dict)][-3:]  # Last 3 exchanges

        # Create enhanced prompt with chat history for MedGemma
        if recent_exchanges:
            history_context = "\n".join(
                f"Human: {h.get('question', '')}\nAssistant: {h.get('answer', '')}"
                for h in recent_exchanges
            )
            prompt = f"""Previous conversation:
{history_context}

Context:
{context}

Current question: {query}

Answer based on the context and conversation history:"""
        else:
            prompt = f"""You are a helpful medical assistant. Use the following context to answer the question.

Context:
{context}

Question: {query}

Answer (based only on the context):"""

        # Generate response using MedGemma-specific function
        response = generate_medgemma_response(prompt, max_tokens)

        return jsonify({
            "answer": response,
            "model": MODEL_INFO['name'],
            "model_path": MODEL_INFO['path'],
            "architecture": MODEL_INFO['architecture'],
            "context_used": len(context) > 0,
            "history_used": len(recent_exchanges) > 0
        })

    except Exception as e:
        logger.error(f"‚ùå Error in generate endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({"error": f"MedGemma text generation failed: {str(e)}"}), 500

@app.route('/query', methods=['POST'])
def query_docs():
    """Return the documents matching ``query`` without generating an answer.

    JSON body fields:
        query: Required search text.
        top_k: Optional result limit; defaults to ``CONFIG['top_k']`` and is
            forwarded unchanged to ``query_supabase_documents``.

    Returns:
        200 with ``documents``, ``total_found`` and ``query``; 400 when the
        body is not a JSON object or ``query`` is missing; 500 on failure.
    """
    try:
        # silent=True returns None for a missing/malformed JSON body instead
        # of raising, so client mistakes surface as 400 instead of 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400

        query = data.get("query", "")
        top_k = data.get("top_k", CONFIG['top_k'])

        if not query:
            return jsonify({"error": "Missing 'query' field."}), 400

        results = query_supabase_documents(query, top_k=top_k)

        return jsonify({
            "documents": results,
            "total_found": len(results),
            "query": query
        })

    except Exception as e:
        logger.error(f"‚ùå Error in query endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({"error": f"Document query failed: {str(e)}"}), 500

@app.route('/ask', methods=['POST'])
def ask_rag():
    """Answer ``question`` through the MedGemma RAG system and record the exchange.

    JSON body fields:
        question: Required question text.
        session_id: Optional key into ``chat_sessions``; defaults to "default".

    Each answered question is appended to the session's history, which is
    trimmed to the 10 most recent exchanges.

    NOTE(review): history (client-supplied or stored) is not passed to
    ``medgemma_rag.query``, which takes only the question, so earlier turns
    do not influence the generated answer. The previous code looked the
    history up but never used it; that dead lookup has been removed.

    Returns:
        200 with ``response``, ``sources``, ``chat_history``, ``session_id``,
        ``mockMode`` and ``metadata``; 400 when the body is not a JSON object
        or ``question`` is missing; 500 with an apology payload on failure.
    """
    try:
        # silent=True returns None for a missing/malformed JSON body instead
        # of raising, so client mistakes surface as 400 instead of 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400

        question = data.get("question", "")
        session_id = data.get("session_id", "default")

        if not question:
            return jsonify({"error": "Missing 'question' field."}), 400

        logger.info(f"ü§ñ Processing MedGemma RAG query: {question[:100]}...")
        logger.info(f"üìù Session ID: {session_id}")

        # Get response using MedGemma RAG system
        result = medgemma_rag.query(question)

        # Update chat history
        session_history = chat_sessions.setdefault(session_id, [])
        session_history.append({"question": question, "answer": result['response']})

        # Keep only last 10 messages to prevent memory issues
        # (the slice is empty, and nothing is deleted, at 10 entries or fewer)
        del session_history[:-10]

        return jsonify({
            "response": result['response'],
            "sources": [
                {
                    "title": source['title'],
                    "content": source['content_preview'],
                    "similarity": float(source['similarity'])
                }
                for source in result['sources']
            ],
            "chat_history": session_history,
            "session_id": session_id,
            "mockMode": False,
            "metadata": result['metadata']
        })

    except Exception as e:
        logger.error(f"‚ùå Error in ask endpoint: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            "response": f"I apologize, but I encountered an error processing your question with MedGemma: {str(e)}",
            "sources": [],
            "chat_history": [],
            "mockMode": True,
            "error": str(e)
        }), 500

@app.route('/chat/history/<session_id>', methods=['GET'])
def get_chat_history(session_id):
    """Return the stored exchanges for *session_id* (empty list if unknown)."""
    try:
        messages = chat_sessions.get(session_id, [])
        payload = {
            "session_id": session_id,
            "history": messages,
            "message_count": len(messages),
        }
        return jsonify(payload)
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500

@app.route('/chat/clear/<session_id>', methods=['POST'])
def clear_chat_history(session_id):
    """Drop the stored history for *session_id*; unknown ids are a no-op."""
    try:
        # pop with a default removes the entry when present and does nothing otherwise
        chat_sessions.pop(session_id, None)
        payload = {"session_id": session_id, "cleared": True}
        return jsonify(payload)
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Enhanced health check endpoint for MedGemma.

    Queries Supabase for the row counts of ``medical_documents`` and
    ``document_embeddings`` and, when CUDA is available, reports GPU memory
    usage for the current device.

    Returns:
        200 with ``status: "healthy"`` and the collected details, or 200 with
        ``status: "partial"`` and a ``warning`` message if any step raises.
    """
    embedding_device = globals().get('EMBEDDING_DEVICE', 'unknown')
    try:
        # Ask the server for an exact total via the client's count option.
        # len() of a select('count') response measures the rows in the
        # response, not the rows in the table. limit(1) keeps the payload small.
        doc_result = (
            supabase.table('medical_documents')
            .select('*', count='exact')
            .limit(1)
            .execute()
        )
        doc_count = doc_result.count or 0

        # Same exact-count query for the embeddings table
        embed_result = (
            supabase.table('document_embeddings')
            .select('*', count='exact')
            .limit(1)
            .execute()
        )
        embed_count = embed_result.count or 0

        # Get GPU memory info if available
        gpu_info = "No GPU available"
        if torch.cuda.is_available():
            try:
                device = torch.cuda.current_device()
                gib = 1024**3
                allocated = torch.cuda.memory_allocated(device) / gib
                reserved = torch.cuda.memory_reserved(device) / gib
                total = torch.cuda.get_device_properties(device).total_memory / gib
                gpu_info = {
                    "allocated": f"{allocated:.2f} GB",
                    "reserved": f"{reserved:.2f} GB",
                    "total": f"{total:.2f} GB",
                    "percentage_used": f"{(allocated/total)*100:.1f}%"
                }
            except Exception:
                # GPU statistics are best-effort; never fail the health check on them
                gpu_info = "GPU info unavailable"

        return jsonify({
            "status": "healthy",
            "model": MODEL_INFO['name'],
            "model_path": MODEL_INFO['path'],
            "architecture": MODEL_INFO['architecture'],
            "embedding_device": embedding_device,
            "database": "Supabase + pgvector",
            "documents_in_db": doc_count,
            "embeddings_in_db": embed_count,
            "rag_system": "MedGemma-enhanced",
            "chat_support": True,
            "active_sessions": len(chat_sessions),
            "gpu_memory": gpu_info
        })
    except Exception as e:
        return jsonify({
            "status": "partial",
            "model": MODEL_INFO['name'],
            "architecture": MODEL_INFO.get('architecture', 'unknown'),
            "embedding_device": embedding_device,
            "database": "Supabase (connection issues)",
            "rag_system": "MedGemma-enhanced",
            "chat_support": True,
            "warning": str(e)
        })

@app.route('/status', methods=['GET'])
def status():
    """Report model load state, database URL prefix, config and chat sessions."""
    try:
        timestamp = datetime.now().isoformat()

        # A model counts as loaded when its name exists among the module globals
        module_names = globals()
        model_states = {
            "medgemma_model": "loaded" if 'medgemma_model' in module_names else "not_loaded",
            "processor": "loaded" if 'processor' in module_names else "not_loaded",
            "pubmedbert": "available",
        }

        # Expose only the first 30 characters of the Supabase URL
        if supabase_url:
            url_preview = supabase_url[:30] + "..."
        else:
            url_preview = "not_set"

        session_summary = {
            "active_sessions": len(chat_sessions),
            "session_ids": list(chat_sessions.keys()),
        }

        return jsonify({
            "timestamp": timestamp,
            "models": model_states,
            "database": {"connected": True, "url": url_preview},
            "config": CONFIG,
            "memory": session_summary,
        })
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500

# Summarise the registered endpoints once the route definitions have run
print("\n".join([
    "üåê Enhanced Flask API endpoints configured for MedGemma:",
    "  ‚úÖ POST /embed - Generate embeddings with MedGemma-safe memory management",
    f"  ‚úÖ POST /generate - Generate text with {MODEL_INFO['name']} + chat history",
    "  ‚úÖ POST /ask - Enhanced RAG endpoint with MedGemma",
    "  ‚úÖ GET /health - Enhanced health check with MedGemma status",
    "  ‚úÖ POST /query - Enhanced document query",
    "  ‚úÖ GET /chat/history/<session_id> - Get chat history",
    "  ‚úÖ POST /chat/clear/<session_id> - Clear chat history",
    "  ‚úÖ GET /status - Detailed MedGemma system status",
    f"‚úÖ Enhanced Flask server ready with {MODEL_INFO['name']} and chat history support!",
]))
