In [4]:
!pip install torch gradio numpy pillow
!pip install docling marker-pdf
!pip install transformers sentence-transformers accelerate
!pip install llama-index llama-index-embeddings-huggingface llama-index-vector-stores-faiss llama-index-core
!pip install faiss-cpu
!pip install python-dotenv tqdm pandas
!pip install -U bitsandbytes
!pip install pymupdfsudo apt-get install -y python3-pip python3-venv build-essential libvips-dev
!pip install pyvips
!sudo apt-get install -y libvips-dev
!pip install pymupdf


Usage:   
  pip3 install [options] <requirement specifier> [package-index-options] ...
  pip3 install [options] -r <requirements file> [package-index-options] ...
  pip3 install [options] [-e] <vcs project url> ...
  pip3 install [options] [-e] <local project path> ...
  pip3 install [options] <archive url/path> ...

no such option: -y
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libvips-dev is already the newest version (8.12.1-1build1).
0 upgraded, 0 newly installed, 0 to remove and 83 not upgraded.
Collecting pymupdf
  Downloading pymupdf-1.26.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (3.4 kB)
Downloading pymupdf-1.26.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (24.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.1/24.1 MB[0m [31m79.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: pymupdf
Successfully installed pymupdf-1.26.0


In [2]:
# Authenticate with the Hugging Face Hub; notebook_login() renders an
# interactive login widget in the cell output.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import os
import sys
import json
import time
import torch
import faiss
import gradio as gr
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import numpy as np
from PIL import Image
import re
import logging
import gc
import psutil
import fitz  # PyMuPDF


# Set up enhanced logging for terminal output: records at INFO level and above
# are sent both to stdout and to the file 'rag_system.log' (relative to the
# current working directory), each prefixed with a timestamp and level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),  # Ensure logs go to stdout
        logging.FileHandler('rag_system.log')  # Also log to file
    ]
)
# Module-level logger used by log_and_print().
logger = logging.getLogger(__name__)


# Force flush for immediate output
def log_and_print(message, level="info"):
    """Record `message` through the module logger and echo it on stdout.

    Args:
        message: Text to emit.
        level: One of "info", "warning" or "error". Any other value skips the
            logger call; the message is still printed.

    The printed copy is prefixed with the current wall-clock time (HH:MM:SS)
    and stdout is flushed right away so the line shows up immediately.
    """
    # The three supported level names match the logger's method names.
    if level in ("info", "warning", "error"):
        getattr(logger, level)(message)

    stamp = datetime.now().strftime('%H:%M:%S')
    print(f"[{stamp}] {message}")
    sys.stdout.flush()


# Create output directories (relative to the current working directory).
# exist_ok=True makes this safe to re-run when the directories already exist.
OUTPUT_DIR = Path("rag_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)
# One sub-folder per artifact type.
# NOTE(review): in the code visible here only "pdf_images" is written to; the
# other folders are created but not referenced - confirm whether they are still needed.
for subdir in ["markdown", "captions", "documents", "chunks", "retrievals", "prompts", "pdf_images"]:
    (OUTPUT_DIR / subdir).mkdir(exist_ok=True)


# Try importing required libraries. Each dependency group is imported inside its
# own try/except: a missing package logs a warning and sets the matching
# *_AVAILABLE flag to False instead of aborting the cell with ImportError.

# Transformers: model/tokenizer loaders and 4-bit quantization config used by
# ImageCaptioner.load_model and PDFChatRAG.initialize_llm_model.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    TRANSFORMERS_AVAILABLE = True
    log_and_print("✓ Transformers available")
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    log_and_print("✗ Transformers not available. Install with: pip install transformers", "warning")


# Marker: PDF conversion used by PDFExtractor.extract_with_marker.
try:
    from marker.converters.pdf import PdfConverter
    from marker.models import create_model_dict
    from marker.output import text_from_rendered
    from marker.config.parser import ConfigParser
    MARKER_AVAILABLE = True
    log_and_print("✓ Marker available")
except ImportError:
    MARKER_AVAILABLE = False
    log_and_print("✗ Marker not available. Install with: pip install marker-pdf", "warning")


# Docling: alternative PDF conversion used by PDFExtractor.extract_with_docling.
try:
    from docling.document_converter import DocumentConverter
    DOCLING_AVAILABLE = True
    log_and_print("✓ Docling available")
except ImportError:
    DOCLING_AVAILABLE = False
    log_and_print("✗ Docling not available. Install with: pip install docling", "warning")


# LlamaIndex: document/node types, semantic splitter, HuggingFace embeddings,
# FAISS vector store and retriever used by the chunking/indexing/query steps.
try:
    from llama_index.core import Document, VectorStoreIndex, StorageContext
    from llama_index.core.node_parser import SemanticSplitterNodeParser
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    from llama_index.vector_stores.faiss import FaissVectorStore
    from llama_index.core.retrievers import VectorIndexRetriever
    from llama_index.core.query_engine import RetrieverQueryEngine
    LLAMAINDEX_AVAILABLE = True
    log_and_print("✓ LlamaIndex available")
except ImportError:
    LLAMAINDEX_AVAILABLE = False
    log_and_print("✗ LlamaIndex not available. Install required packages", "warning")


# Sentence-transformers: the flag set here gates
# PDFChatRAG.initialize_embedding_model.
try:
    from sentence_transformers import SentenceTransformer
    SENTENCE_TRANSFORMERS_AVAILABLE = True
    log_and_print("✓ Sentence transformers available")
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    log_and_print("✗ Sentence transformers not available. Install with: pip install sentence-transformers", "warning")


class MemoryMonitor:
    """Report GPU and system memory usage and release cached GPU memory.

    All members are static methods; the class is used as a namespace and is
    never instantiated. Sizes are reported in GiB (bytes / 1024**3).
    """

    @staticmethod
    def get_gpu_memory():
        """Return memory statistics for the current CUDA device.

        Returns:
            dict: ``{"available": False}`` when CUDA is not available, or
            ``{"available": False, "error": <message>}`` when querying fails.
            Otherwise a dict with ``available=True`` and the keys ``total``,
            ``allocated``, ``cached`` (memory reserved by the allocator),
            ``free`` (``total - allocated``) and ``usage_percent``.
        """
        if not torch.cuda.is_available():
            return {"available": False}

        try:
            device = torch.cuda.current_device()
            total = torch.cuda.get_device_properties(device).total_memory / 1024**3  # GB
            allocated = torch.cuda.memory_allocated(device) / 1024**3  # GB
            cached = torch.cuda.memory_reserved(device) / 1024**3  # GB
            free = total - allocated

            return {
                "available": True,
                "total": total,
                "allocated": allocated,
                "cached": cached,
                "free": free,
                "usage_percent": (allocated / total) * 100
            }
        except Exception as e:
            log_and_print(f"Error getting GPU memory: {e}", "error")
            return {"available": False, "error": str(e)}

    @staticmethod
    def get_system_memory():
        """Return system RAM statistics from psutil.

        Returns:
            dict: keys ``total``, ``available``, ``used`` (GiB) and
            ``usage_percent``; or ``{"error": <message>}`` when the query fails.
        """
        try:
            memory = psutil.virtual_memory()
            return {
                "total": memory.total / 1024**3,  # GB
                "available": memory.available / 1024**3,  # GB
                "used": memory.used / 1024**3,  # GB
                "usage_percent": memory.percent
            }
        except Exception as e:
            log_and_print(f"Error getting system memory: {e}", "error")
            return {"error": str(e)}

    @staticmethod
    def clear_gpu_cache():
        """Release cached GPU memory; does nothing when CUDA is unavailable."""
        if torch.cuda.is_available():
            # Collect garbage first: tensors kept alive only by reference
            # cycles are freed here, so the blocks they occupied become
            # releasable by empty_cache(). Doing it the other way round leaves
            # that memory cached.
            gc.collect()
            torch.cuda.empty_cache()
            log_and_print("GPU cache cleared")

    @staticmethod
    def get_memory_summary():
        """Return a Markdown report combining the GPU and system RAM statistics."""
        gpu_mem = MemoryMonitor.get_gpu_memory()
        sys_mem = MemoryMonitor.get_system_memory()

        summary = "## Memory Usage\n\n"

        if gpu_mem.get("available"):
            summary += "**GPU Memory:**\n"
            summary += f"- Total: {gpu_mem['total']:.2f} GB\n"
            summary += f"- Allocated: {gpu_mem['allocated']:.2f} GB ({gpu_mem['usage_percent']:.1f}%)\n"
            summary += f"- Cached: {gpu_mem['cached']:.2f} GB\n"
            summary += f"- Free: {gpu_mem['free']:.2f} GB\n\n"
        else:
            summary += "**GPU Memory:** Not available\n\n"

        if "error" not in sys_mem:
            summary += "**System RAM:**\n"
            summary += f"- Total: {sys_mem['total']:.2f} GB\n"
            summary += f"- Used: {sys_mem['used']:.2f} GB ({sys_mem['usage_percent']:.1f}%)\n"
            summary += f"- Available: {sys_mem['available']:.2f} GB\n\n"
        else:
            summary += "**System RAM:** Error reading\n\n"

        return summary


class PDFExtractor:
    """Convert a PDF to markdown text with either Docling or Marker.

    Converters are created lazily on first use and released again after every
    extraction (successful or not) so their models do not stay resident while
    later pipeline stages load their own models.
    """

    def __init__(self, force_cpu: bool = False):
        """Record the device preference; no converter is loaded yet.

        Args:
            force_cpu: When True the device is "cpu" even if CUDA is available.
        """
        self.docling_converter = None
        self.marker_converter = None
        self.marker_llm_converter = None
        self.force_cpu = force_cpu
        self.device = "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else "cpu")
        log_and_print(f"PDFExtractor initialized with device: {self.device}")

    def clear_models(self):
        """Drop every converter reference and release cached GPU memory."""
        self.docling_converter = None
        self.marker_converter = None
        self.marker_llm_converter = None
        gc.collect()
        MemoryMonitor.clear_gpu_cache()
        log_and_print("PDF extractor models cleared")

    def extract_with_docling(self, pdf_path: str) -> Tuple[str, float]:
        """Extract markdown from `pdf_path` using Docling.

        Args:
            pdf_path: Path of the PDF file to convert.

        Returns:
            Tuple of (markdown text, elapsed seconds including converter load).

        Raises:
            ImportError: If Docling could not be imported.
            Exception: Whatever Docling raises while loading or converting.
        """
        if not DOCLING_AVAILABLE:
            raise ImportError("Docling not available")

        log_and_print(f"Extracting PDF with Docling: {os.path.basename(pdf_path)} on {self.device}")
        start_time = time.time()

        try:
            if self.docling_converter is None:
                # NOTE(review): DocumentConverter() is built with defaults, so
                # self.device is only informational here - confirm whether
                # Docling should be given an explicit accelerator setting.
                log_and_print(f"Loading Docling converter on {self.device}...")
                try:
                    self.docling_converter = DocumentConverter()
                    log_and_print("Docling converter loaded successfully")
                except Exception as e:
                    log_and_print(f"Error initializing Docling: {e}", "error")
                    raise

            log_and_print("Starting Docling conversion...")
            result = self.docling_converter.convert(pdf_path)
            markdown_text = result.document.export_to_markdown()

            extraction_time = time.time() - start_time
            log_and_print(f"Docling extraction completed in {extraction_time:.2f}s")
            return markdown_text, extraction_time
        finally:
            # Release the converter even when conversion fails, otherwise its
            # models stay loaded until the next clear.
            self.clear_models()

    def _build_marker_converter(self, use_llm: bool):
        """Create a Marker PdfConverter, optionally configured for LLM-assisted mode."""
        # Pass the device so force_cpu is honoured instead of Marker auto-selecting one.
        model_dict = create_model_dict(device=self.device)
        if not use_llm:
            return PdfConverter(artifact_dict=model_dict)

        # LLM mode follows Marker's documented ConfigParser-based setup.
        # NOTE(review): the LLM service needs its own credentials/configuration
        # (see the Marker documentation); none are supplied here.
        config_parser = ConfigParser({"output_format": "markdown", "use_llm": True})
        return PdfConverter(
            config=config_parser.generate_config_dict(),
            artifact_dict=model_dict,
            processor_list=config_parser.get_processors(),
            renderer=config_parser.get_renderer(),
            llm_service=config_parser.get_llm_service(),
        )

    def extract_with_marker(self, pdf_path: str, use_llm: bool = False) -> Tuple[str, float]:
        """Extract markdown from `pdf_path` using Marker.

        Args:
            pdf_path: Path of the PDF file to convert.
            use_llm: When True, use a converter configured with Marker's
                ``use_llm`` option (kept in ``marker_llm_converter``);
                otherwise use the plain converter (``marker_converter``).

        Returns:
            Tuple of (extracted text, elapsed seconds including converter load).

        Raises:
            ImportError: If Marker could not be imported.
            Exception: Whatever Marker raises while loading or converting.
        """
        if not MARKER_AVAILABLE:
            raise ImportError("Marker not available")

        log_and_print(f"Extracting PDF with Marker (LLM: {use_llm}): {os.path.basename(pdf_path)} on {self.device}")
        start_time = time.time()

        try:
            converter = self.marker_llm_converter if use_llm else self.marker_converter
            if converter is None:
                log_and_print(f"Loading Marker converter on {self.device}...")
                try:
                    converter = self._build_marker_converter(use_llm)
                    log_and_print("Marker converter loaded successfully")
                except Exception as e:
                    log_and_print(f"Error loading Marker converter: {e}", "error")
                    raise
                if use_llm:
                    self.marker_llm_converter = converter
                else:
                    self.marker_converter = converter

            log_and_print("Starting Marker conversion...")
            rendered = converter(pdf_path)
            text, _, _ = text_from_rendered(rendered)

            extraction_time = time.time() - start_time
            log_and_print(f"Marker extraction completed in {extraction_time:.2f}s")
            return text, extraction_time
        finally:
            # Release the converter even when conversion fails.
            self.clear_models()


class ImageCaptioner:
    """Extract embedded images from a PDF and caption them with Moondream."""

    # Prompt sent to the vision model for every image.
    CAPTION_PROMPT = "Describe this image in detail."

    def __init__(self, force_cpu: bool = False):
        """Record the device preference; the model is loaded lazily.

        Args:
            force_cpu: When True the device is "cpu" even if CUDA is available.
        """
        self.model = None
        self.tokenizer = None
        self.force_cpu = force_cpu
        self.device = "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else "cpu")

    def clear_model(self):
        """Drop the model and tokenizer and release cached GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        gc.collect()
        MemoryMonitor.clear_gpu_cache()
        log_and_print("Image captioning model cleared")

    def load_model(self):
        """Load Moondream (4-bit quantized on CUDA) if it is not loaded yet.

        Failures are logged and leave ``self.model`` / ``self.tokenizer`` as
        None; nothing is raised.
        """
        if self.model is None and TRANSFORMERS_AVAILABLE:
            model_id = "vikhyatk/moondream2"
            log_and_print(f"Loading Moondream model '{model_id}' on {self.device}...")
            on_cuda = self.device == "cuda"
            try:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    trust_remote_code=True,
                    quantization_config=quantization_config if on_cuda else None,
                    # "auto" would still place the model on a GPU when one
                    # exists, so pin it to the CPU when force_cpu is in effect.
                    device_map="auto" if on_cuda else {"": "cpu"}
                )
                self.tokenizer = AutoTokenizer.from_pretrained(model_id)
                log_and_print(f"Moondream model loaded successfully on {self.device}")
            except Exception as e:
                log_and_print(f"Failed to load Moondream model: {e}", "error")
                self.model = None
                self.tokenizer = None

    def extract_images_from_pdf(self, pdf_path: str, output_dir: Path) -> List[Dict[str, str]]:
        """Write every embedded image of the PDF to `output_dir`.

        Args:
            pdf_path: Path of the PDF file.
            output_dir: Existing directory that receives the image files, named
                ``page<N>_img<M>.<ext>`` (1-based page and image numbers).

        Returns:
            One dict per image with keys ``index`` (0-based running number),
            ``path`` and ``page`` (1-based). On error the images collected so
            far are returned and the error is logged.
        """
        images = []
        try:
            # The context manager closes the document even when extraction fails.
            with fitz.open(pdf_path) as doc:
                for page_num in range(len(doc)):
                    for img_index, img in enumerate(doc.get_page_images(page_num)):
                        xref = img[0]
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]
                        image_ext = base_image["ext"]
                        image_path = output_dir / f"page{page_num+1}_img{img_index+1}.{image_ext}"
                        with open(image_path, "wb") as f:
                            f.write(image_bytes)
                        images.append({
                            "index": len(images),
                            "path": str(image_path),
                            "page": page_num + 1,
                        })
            log_and_print(f"Extracted {len(images)} images from {os.path.basename(pdf_path)}")
        except Exception as e:
            log_and_print(f"Error extracting images from PDF: {e}", "error")
        return images

    def caption_image(self, image_path: str) -> str:
        """Return a caption for the image at `image_path`.

        The model is loaded on demand. Returns a fixed message when the model
        is unavailable and an ``"Error processing image: ..."`` string when
        captioning fails; nothing is raised.
        """
        if self.model is None or self.tokenizer is None:
            self.load_model()

        if self.model is None:
            return "Image captioning model not available"

        try:
            # Keep the file handle open only for the duration of the inference.
            with Image.open(image_path) as image:
                if hasattr(self.model, "query"):
                    # Newer Moondream revisions expose query() and return a dict.
                    # NOTE(review): the model revision is not pinned, so which
                    # API is present depends on what the Hub serves - confirm.
                    return self.model.query(image, self.CAPTION_PROMPT)["answer"]
                # Older revisions: encode first, then answer; their
                # answer_question() takes the tokenizer as third argument.
                enc_image = self.model.encode_image(image)
                return self.model.answer_question(enc_image, self.CAPTION_PROMPT, self.tokenizer)
        except Exception as e:
            log_and_print(f"Error captioning image {image_path}: {e}", "error")
            return f"Error processing image: {str(e)}"

    def process_document_images(self, pdf_path: str) -> Tuple[str, List[Dict]]:
        """Extract and caption all images of a PDF.

        Images are written to a timestamped folder under
        ``OUTPUT_DIR / "pdf_images"``.

        Returns:
            Tuple of (markdown section listing every caption, list of caption
            dicts with keys ``index``, ``page``, ``caption``, ``path``).
            ``("", [])`` when the PDF has no images or the captioning model
            cannot be loaded.
        """
        image_output_dir = OUTPUT_DIR / "pdf_images" / datetime.now().strftime("%Y%m%d_%H%M%S")
        image_output_dir.mkdir(parents=True, exist_ok=True)

        images = self.extract_images_from_pdf(pdf_path, image_output_dir)
        if not images:
            log_and_print("No images found in document")
            return "", []

        log_and_print(f"Found {len(images)} images in document")
        self.load_model()
        if self.model is None:
            # Without a model every caption would be the same placeholder text,
            # which would only add noise to the indexed document.
            log_and_print("Image captioning model not available; skipping captions", "warning")
            return "", []

        captions_data = []
        captions_markdown = "\n\n## Image Captions\n\n"

        try:
            for img_info in images:
                caption = self.caption_image(img_info["path"])
                captions_data.append({
                    "index": img_info["index"],
                    "page": img_info["page"],
                    "caption": caption,
                    "path": img_info["path"]
                })
                captions_markdown += f"**Image on Page {img_info['page']} (path: {img_info['path']}):**\n{caption}\n\n"
        finally:
            # Free the model even if captioning is interrupted.
            self.clear_model()
        return captions_markdown, captions_data


# The LlamaIndex import above is guarded, so HuggingFaceEmbedding may be
# undefined. Defining the subclass unconditionally would then raise NameError
# and defeat the LLAMAINDEX_AVAILABLE fallback.
if LLAMAINDEX_AVAILABLE:
    class QwenEmbedding(HuggingFaceEmbedding):
        """HuggingFaceEmbedding preset for "Qwen/Qwen3-Embedding-0.6B".

        Queries are wrapped in an "Instruct: ...\\nQuery: ..." prefix before
        being embedded; document text is embedded by the base class unchanged.
        """

        def __init__(self, **kwargs):
            """Create the embedding model; `kwargs` are forwarded to HuggingFaceEmbedding."""
            super().__init__(
                model_name="Qwen/Qwen3-Embedding-0.6B",
                **kwargs
            )
            # NOTE(review): depending on the installed llama-index version the
            # base class may also apply `query_instruction` itself, which would
            # prefix the instruction twice - confirm against the pinned version.
            self.query_instruction = "Given a web search query, retrieve relevant passages that answer the query"

        def _get_query_embedding(self, query: str) -> List[float]:
            """Embed `query` after prepending the retrieval instruction."""
            instructed_query = f"Instruct: {self.query_instruction}\nQuery: {query}"
            return super()._get_query_embedding(instructed_query)
else:
    class QwenEmbedding:
        """Placeholder used when LlamaIndex is not installed.

        Instantiating it raises ImportError, which callers that wrap model
        loading in try/except report as a load failure.
        """

        def __init__(self, **kwargs):
            raise ImportError(
                "LlamaIndex HuggingFace embeddings not available. "
                "Install with: pip install llama-index-embeddings-huggingface"
            )


class PDFChatRAG:
    """End-to-end RAG pipeline: extract, caption, chunk, index and query a PDF.

    Models are loaded lazily and released between stages to limit memory use.
    ``debug_info`` collects intermediate artifacts (markdown, captions, chunks,
    retrieved chunks, final prompt, raw LLM output) for the debug UI.
    """

    # Upper bound on prompt tokens passed to the LLM, and on generated tokens.
    MAX_PROMPT_TOKENS = 7168
    MAX_NEW_TOKENS = 1024
    # Tokens kept free for the instruction text and chat-template markup that
    # surround the retrieved context in the prompt.
    PROMPT_OVERHEAD_TOKENS = 256

    def __init__(self, force_cpu: bool = False):
        """Create the extractor and captioner; no model is loaded yet.

        Args:
            force_cpu: When True all components are told to stay on the CPU.
        """
        self.extractor = PDFExtractor(force_cpu=force_cpu)
        self.captioner = ImageCaptioner(force_cpu=force_cpu)
        self.embed_model = None
        self.llm_model = None
        self.llm_tokenizer = None
        self.index = None
        self.current_session_id = None
        self.force_cpu = force_cpu
        self.processed_pdf_path = None
        self.debug_info = {}
        log_and_print(f"PDFChatRAG initialized with force_cpu: {force_cpu}")

    @staticmethod
    def _resolve_device(force_cpu: bool) -> str:
        """Return "cuda" when allowed and available, otherwise "cpu"."""
        return "cpu" if force_cpu else ("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _truncate_to_tokens(tokenizer, text: str, max_tokens: int) -> str:
        """Return `text` cut to at most `max_tokens` tokens of `tokenizer`.

        The text is returned unchanged when it already fits; a non-positive
        budget yields an empty string.
        """
        if max_tokens <= 0:
            return ""
        token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        if len(token_ids) <= max_tokens:
            return text
        return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)

    @staticmethod
    def _extract_answer(full_output: str) -> str:
        """Return the user-facing answer from the raw decoded LLM output.

        Keeps only the text after the last ``<|im_start|>assistant`` marker and
        before ``<|im_end|>``, then drops everything up to and including the
        first ``</think>`` so the reasoning block is not shown in the chat.
        """
        answer = full_output.split("<|im_start|>assistant")[-1]
        if "<|im_end|>" in answer:
            answer = answer.split("<|im_end|>")[0]
        _, closing_tag, after_think = answer.partition("</think>")
        if closing_tag:
            answer = after_think
        return answer.strip()

    def update_force_cpu(self, force_cpu: bool):
        """Update the force_cpu setting and derived device on all components.

        Models that are already loaded are not moved; the new setting applies
        the next time a model is loaded.
        """
        device = self._resolve_device(force_cpu)
        self.force_cpu = force_cpu
        self.extractor.force_cpu = force_cpu
        self.extractor.device = device
        self.captioner.force_cpu = force_cpu
        self.captioner.device = device
        log_and_print(f"Force CPU mode updated to: {force_cpu}")

    def clear_all_models(self):
        """Release the extractor, captioner, embedding and LLM models."""
        log_and_print("Clearing all models to free memory...")
        self.extractor.clear_models()
        self.captioner.clear_model()
        if self.embed_model:
            del self.embed_model
            self.embed_model = None
        if self.llm_model:
            del self.llm_model
            self.llm_model = None
        if self.llm_tokenizer:
            del self.llm_tokenizer
            self.llm_tokenizer = None
        gc.collect()
        MemoryMonitor.clear_gpu_cache()
        log_and_print("All models cleared")

    def initialize_embedding_model(self):
        """Load the Qwen embedding model if it is not loaded yet.

        Failures are logged and leave ``self.embed_model`` as None.
        """
        if self.embed_model is None and SENTENCE_TRANSFORMERS_AVAILABLE:
            log_and_print("Loading Qwen embedding model...")
            try:
                # Pass the device explicitly so force_cpu is honoured.
                self.embed_model = QwenEmbedding(device=self._resolve_device(self.force_cpu))
                log_and_print("Embedding model loaded successfully")
            except Exception as e:
                log_and_print(f"Failed to load embedding model: {e}", "error")
                self.embed_model = None

    def initialize_llm_model(self):
        """Load the answer-generation LLM and tokenizer if not loaded yet.

        Failures are logged and leave ``self.llm_model`` and
        ``self.llm_tokenizer`` as None.
        """
        if self.llm_model is None and TRANSFORMERS_AVAILABLE:
            model_name = "unsloth/Qwen3-4B-unsloth-bnb-4bit"
            log_and_print(f"Loading LLM model: {model_name}...")
            on_cuda = self._resolve_device(self.force_cpu) == "cuda"
            try:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
                self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
                # NOTE(review): the checkpoint name indicates a pre-quantized
                # bitsandbytes model; loading it on CPU may fail - confirm.
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config if on_cuda else None,
                    # "auto" would still use a GPU when one exists, so pin the
                    # model to the CPU when force_cpu is in effect.
                    device_map="auto" if on_cuda else {"": "cpu"},
                    trust_remote_code=True
                )
                if self.llm_tokenizer.pad_token is None:
                    self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
                log_and_print(f"Successfully loaded {model_name}")
            except Exception as e:
                log_and_print(f"Failed to load {model_name}: {e}", "error")
                self.llm_model = None
                self.llm_tokenizer = None

    def process_pdf(self, pdf_path: str, extraction_method: str, use_llm_with_marker: bool = False, progress_callback=None) -> Dict:
        """Run extraction, captioning, chunking and indexing for one PDF.

        Args:
            pdf_path: Path of the PDF file.
            extraction_method: "Docling" selects Docling; any other value
                selects Marker.
            use_llm_with_marker: Forwarded to the Marker extractor as `use_llm`.
            progress_callback: Optional ``callable(progress: float, text: str)``
                invoked at each step.

        Returns:
            ``{"status": "success"}`` or ``{"status": "failed", "error": str}``.
            A captioning failure is not fatal: the document is indexed without
            captions.
        """
        self.current_session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.processed_pdf_path = pdf_path
        self.debug_info = {}
        # Drop the previous index up front so a failed run cannot leave queries
        # silently answering from the previously processed document.
        self.index = None

        def update_progress(step: str, progress: float, status: str = ""):
            if progress_callback:
                progress_callback(progress, f"{step}: {status}")
            log_and_print(f"{step} ({progress*100:.1f}%): {status}")

        update_progress("Initialization", 0.05, "Clearing existing models")
        self.clear_all_models()

        update_progress("Extraction", 0.1, f"Extracting text using {extraction_method}")
        try:
            if extraction_method == "Docling":
                markdown_text, _ = self.extractor.extract_with_docling(pdf_path)
            else:
                markdown_text, _ = self.extractor.extract_with_marker(pdf_path, use_llm=use_llm_with_marker)
            self.debug_info['markdown'] = markdown_text
            update_progress("Extraction", 0.25, "Text extraction completed")
        except Exception as e:
            log_and_print(f"Extraction failed: {e}", "error")
            return {"status": "failed", "error": str(e)}

        update_progress("Captioning", 0.3, "Processing images and generating captions")
        try:
            captions_markdown, captions_data = self.captioner.process_document_images(pdf_path)
            self.debug_info['image_captions'] = captions_data
            document_text = markdown_text + captions_markdown
            self.debug_info['document_with_captions'] = document_text
            update_progress("Captioning", 0.4, f"Processed {len(captions_data)} images")
        except Exception as e:
            log_and_print(f"Captioning failed: {e}", "error")
            document_text = markdown_text

        update_progress("Chunking", 0.5, "Creating semantic chunks")
        try:
            self.initialize_embedding_model()
            if self.embed_model is None:
                raise Exception("Embedding model not available")

            document = Document(text=document_text, metadata={"source": pdf_path})
            splitter = SemanticSplitterNodeParser(
                buffer_size=1,
                breakpoint_percentile_threshold=95,
                embed_model=self.embed_model
            )
            nodes = splitter.get_nodes_from_documents([document])
            self.debug_info['chunks'] = [node.text for node in nodes]
            update_progress("Chunking", 0.7, f"Created {len(nodes)} semantic chunks")
        except Exception as e:
            log_and_print(f"Chunking failed: {e}", "error")
            return {"status": "failed", "error": str(e)}

        update_progress("Indexing", 0.8, "Creating FAISS vector store")
        try:
            # Ask the model for its vector size rather than hard-coding it, so
            # the FAISS index always matches the embedding model in use.
            dimension = len(self.embed_model.get_text_embedding("dimension probe"))
            faiss_index = faiss.IndexFlatL2(dimension)
            vector_store = FaissVectorStore(faiss_index=faiss_index)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            self.index = VectorStoreIndex(
                nodes=nodes,
                storage_context=storage_context,
                embed_model=self.embed_model
            )
            update_progress("Indexing", 1.0, "Vector index created successfully")
        except Exception as e:
            log_and_print(f"Indexing failed: {e}", "error")
            return {"status": "failed", "error": str(e)}

        # The LLM is intentionally not loaded here; it is loaded on the first query.
        log_and_print("PDF processing completed successfully. LLM will be loaded on first query.")
        return {"status": "success"}

    def query(self, question: str, top_k: int = 5) -> Dict:
        """Answer `question` from the indexed document.

        Args:
            question: The user's question.
            top_k: Number of chunks to retrieve as context.

        Returns:
            ``{"answer": str, "status": "success"}`` on success; otherwise a
            dict with ``"status": "failed"`` and an ``"error"`` message (the
            LLM-unavailable case also carries the message under ``"answer"``).
        """
        if self.index is None:
            return {"error": "No document indexed. Please process a PDF first.", "status": "failed"}

        try:
            # Embedding and LLM models are loaded only when actually needed.
            self.initialize_embedding_model()
            self.initialize_llm_model()
            llm_ready = self.llm_model is not None and self.llm_tokenizer is not None

            log_and_print("Retrieving relevant chunks...")
            retriever = VectorIndexRetriever(index=self.index, similarity_top_k=top_k)
            retrieved_nodes = retriever.retrieve(question)

            context_texts = [node.node.text for node in retrieved_nodes]
            self.debug_info['retrieved_chunks'] = context_texts
            context = "\n\n".join(context_texts)

            if llm_ready:
                # Trim the context itself. The tokenizer truncates on the right,
                # and the question and generation prompt sit at the end of the
                # prompt, so an over-long context would cut them off.
                question_tokens = len(self.llm_tokenizer(question, add_special_tokens=False)["input_ids"])
                context_budget = self.MAX_PROMPT_TOKENS - self.PROMPT_OVERHEAD_TOKENS - question_tokens
                context = self._truncate_to_tokens(self.llm_tokenizer, context, context_budget)

            rag_prompt_content = f"""Please answer the following question based on the provided context. Be direct and specific.

Context:
{context}

Question: {question}"""
            self.debug_info['final_prompt'] = rag_prompt_content

            if not llm_ready:
                message = "LLM model not available."
                # "error" is included so callers that only read that key show
                # the real reason rather than a generic message.
                return {"answer": message, "error": message, "status": "failed"}

            log_and_print("Generating answer with LLM...")
            messages = [{"role": "user", "content": rag_prompt_content}]
            inputs = self.llm_tokenizer(
                self.llm_tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                ),
                return_tensors="pt",
                truncation=True,  # safety net; the context was already trimmed
                max_length=self.MAX_PROMPT_TOKENS
            ).to(self.llm_model.device)

            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=self.MAX_NEW_TOKENS,
                    pad_token_id=self.llm_tokenizer.eos_token_id
                )
            # Keep the raw output (special and thinking tokens included) for the debug view.
            full_answer = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=False)
            self.debug_info['full_llm_output'] = full_answer

            answer = self._extract_answer(full_answer)
            log_and_print("Answer generation complete.")
            return {"answer": answer, "status": "success"}
        except Exception as e:
            log_and_print(f"Query failed: {e}", "error")
            return {"status": "failed", "error": str(e)}


# Gradio interface: wires the PDFChatRAG pipeline to an upload/chat UI.
def create_gradio_interface():
    """Build the Gradio Blocks UI for uploading a PDF and chatting about it.

    One PDFChatRAG instance is created per interface and shared by all event
    handlers. Every handler returns one update per component in `all_outputs`
    (the chatbot followed by the debug panels).

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    rag_system = PDFChatRAG()

    with gr.Blocks(title="PDF Chat RAG System", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# PDF Chat RAG System")
        gr.Markdown("Upload a PDF, process it, and ask questions about its content.")

        with gr.Row():
            with gr.Column(scale=1):
                pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                extraction_method = gr.Radio(choices=["Docling", "Marker"], value="Marker", label="Extraction Method")
                debug_mode = gr.Checkbox(label="Debug Mode", value=False)
                process_btn = gr.Button("Process PDF", variant="primary")
                # Snapshot taken when the UI is built; no handler refreshes it.
                memory_display = gr.Markdown(MemoryMonitor.get_memory_summary(), label="Memory Status")

            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Chat", height=600, type="tuples") # Keep type for compatibility
                with gr.Row():
                    question_input = gr.Textbox(label="Your Question", placeholder="Ask a question about the document...", scale=4)
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Accordion("Debug Information", open=False, visible=False) as debug_accordion:
            with gr.Tabs() as debug_tabs:
                with gr.TabItem("PDF vs Markdown"):
                    with gr.Row():
                        pdf_display = gr.File(label="Uploaded PDF")
                        markdown_display = gr.Markdown(label="Extracted Markdown")
                with gr.TabItem("Image Captions"):
                    captions_display = gr.JSON(label="Image Captions")
                with gr.TabItem("Document with Captions"):
                    doc_with_captions_display = gr.Markdown(label="Document with Embedded Captions")
                with gr.TabItem("Chunks"):
                    chunks_display = gr.JSON(label="Text Chunks")
                with gr.TabItem("Retrieved Chunks"):
                    retrieved_chunks_display = gr.JSON(label="Retrieved Chunks for Last Query")
                with gr.TabItem("Final Prompt"):
                    prompt_display = gr.Markdown(label="Final Prompt to LLM")
                with gr.TabItem("Full LLM Output"):
                    full_output_display = gr.Markdown(label="Full Output from LLM (with thinking tokens)")

        # List of all output components for updates
        all_outputs = [
            chatbot, pdf_display, markdown_display, captions_display,
            doc_with_captions_display, chunks_display, retrieved_chunks_display,
            prompt_display, full_output_display
        ]

        def chat_only_update(history):
            """Update the chatbot with `history` and leave every debug panel untouched."""
            return [gr.update(value=history)] + [gr.update()] * (len(all_outputs) - 1)

        def process_pdf_wrapper(pdf_file, extraction, is_debug, progress=gr.Progress()):
            """Run the pipeline on the uploaded file and report the outcome in the chat."""
            if pdf_file is None:
                return chat_only_update([(None, "Please upload a PDF file.")])

            rag_system.update_force_cpu(False) # Or add a checkbox for this

            def progress_callback(pct, desc):
                progress(pct, desc=desc)

            results = rag_system.process_pdf(pdf_file.name, extraction, progress_callback=progress_callback)

            if results["status"] != "success":
                failure = results.get('error', 'An unknown error occurred.')
                return chat_only_update([(None, f"Failed to process PDF: {failure}")])

            chat_history = [
                (None, f"Successfully processed '{os.path.basename(pdf_file.name)}'. You can now ask questions.")
            ]
            if not is_debug:
                return chat_only_update(chat_history)

            # Debug mode: fill the extraction panels and clear the per-query ones.
            return (
                gr.update(value=chat_history),
                gr.update(value=pdf_file.name),
                gr.update(value=rag_system.debug_info.get('markdown', '')),
                gr.update(value=rag_system.debug_info.get('image_captions', {})),
                gr.update(value=rag_system.debug_info.get('document_with_captions', '')),
                gr.update(value=rag_system.debug_info.get('chunks', [])),
                gr.update(value=None), # Clear previous retrieval
                gr.update(value=None), # Clear previous prompt
                gr.update(value=None)  # Clear previous llm output
            )

        def chat_wrapper(question, chat_history, is_debug):
            """Generator handler: show a placeholder, then the answer or error."""
            # Work on a copy: the incoming history may be None and should not
            # be mutated in place.
            chat_history = list(chat_history or [])

            # Ignore empty submissions instead of loading models for nothing.
            if not question or not question.strip():
                yield chat_only_update(chat_history)
                return

            chat_history.append((question, "Thinking..."))
            # First yield makes the placeholder visible while the query runs.
            yield chat_only_update(chat_history)

            results = rag_system.query(question)

            if results['status'] != 'success':
                chat_history[-1] = (question, f"Error: {results.get('error', 'An unknown error occurred.')}")
                yield chat_only_update(chat_history)
                return

            chat_history[-1] = (question, results['answer'])
            if not is_debug:
                yield chat_only_update(chat_history)
                return

            # Debug mode: keep the extraction panels, refresh the per-query ones.
            yield (
                gr.update(value=chat_history),
                gr.update(), # Keep PDF display
                gr.update(), # Keep markdown
                gr.update(), # Keep captions
                gr.update(), # Keep doc with captions
                gr.update(), # Keep chunks
                gr.update(value=rag_system.debug_info.get('retrieved_chunks', [])),
                gr.update(value=rag_system.debug_info.get('final_prompt', '')),
                gr.update(value=rag_system.debug_info.get('full_llm_output', ''))
            )

        process_btn.click(
            process_pdf_wrapper,
            inputs=[pdf_input, extraction_method, debug_mode],
            outputs=all_outputs
        )

        # Both the Send button and pressing Enter run the chat handler, then
        # clear the question box.
        submit_btn.click(
            chat_wrapper,
            inputs=[question_input, chatbot, debug_mode],
            outputs=all_outputs
        ).then(lambda: gr.update(value=""), outputs=[question_input])

        question_input.submit(
            chat_wrapper,
            inputs=[question_input, chatbot, debug_mode],
            outputs=all_outputs
        ).then(lambda: gr.update(value=""), outputs=[question_input])

        # The debug accordion is shown only while the checkbox is ticked.
        debug_mode.change(
            lambda x: gr.update(visible=x),
            inputs=[debug_mode],
            outputs=[debug_accordion]
        )

    return interface


# Entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    log_and_print("Starting Gradio interface...")
    interface = create_gradio_interface()
    # share=True requests a public share link for the app; debug=True keeps
    # launch() blocking. NOTE(review): a public link exposes PDF upload and
    # processing to anyone who has the URL - confirm this is intended.
    interface.launch(share=True, debug=True)

ModuleNotFoundError: No module named 'fitz'