<a href="https://colab.research.google.com/github/mucherlaananthalakshmi-web/Agentic-Project/blob/main/claims_management.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# Enhanced RAG-Based Expense Claims Processing with Filename IDs and Token Tracking
# Complete workflow with detailed progress tracking and usage monitoring

import subprocess
import time
import os
import json
import pandas as pd
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict
import warnings
import uuid
warnings.filterwarnings('ignore')

print("🎯 Enhanced RAG-Based Expense Claims Processing System")
print("🔍 Features: Filename-based IDs + Token Usage Tracking")
print("=" * 70)

# ================================
# STEP 1: SYSTEM SETUP & INSTALLATION
# ================================

print("📦 Installing dependencies...")

# Core RAG dependencies
import sys

subprocess.run([
    sys.executable, "-m", "pip", "install", "-q",  # Install into the running kernel's environment
    "ollama",
    "langchain-ollama",
    "langchain-core",
    "langchain-community",
    "chromadb>=0.4.0",
    "unstructured[pdf]>=0.10.0",
    "sentence-transformers",
    "pandas",
    "pillow",
    "python-dateutil",
    "pydantic",
    "langgraph"
], check=True)

print("✅ Dependencies installed")

# ================================
# STEP 1.5: PRE-DOWNLOAD UNSTRUCTURED MODELS
# ================================

print("\n📥 Pre-downloading UnstructuredIO models...")
print("=" * 70)

def predownload_unstructured_models():
    """Pre-download all necessary UnstructuredIO models"""

    try:

        # Import the layout model which triggers YOLO download
        from unstructured.partition.pdf import partition_pdf
        from unstructured.partition.auto import partition

        # Try to initialize the model by importing layout detection
        try:
            from unstructured.models import detectron2_onnx
        except Exception:
            # Optional module: it may be absent or fail to load in some versions
            pass

        try:
            from unstructured.partition.utils.constants import Source
            from unstructured.documents.elements import Element
        except Exception:
            pass

        # Force download by attempting to process a dummy image
        # This triggers YOLO model download

        # Create a minimal dummy image to trigger model download
        from PIL import Image

        # Save a simple white image to a temporary file
        temp_path = "/tmp/dummy_image.png"
        Image.new('RGB', (100, 100), color='white').save(temp_path, format='PNG')

        # Try to partition it to trigger model downloads
        try:
            from unstructured.partition.auto import partition
            elements = partition(filename=temp_path)
            print("✅ YOLO model (yolox_l0.05.onnx) downloaded successfully")
        except Exception as e:
            print(f"⚠️ Model download triggered, may complete in background: {e}")

        # Clean up
        if os.path.exists(temp_path):
            os.remove(temp_path)

        # Check that the Excel/CSV partitioners import (the imports themselves download nothing)
        try:
            from unstructured.partition.xlsx import partition_xlsx
            from unstructured.partition.csv import partition_csv
        except ImportError:
            print("ℹ️ Excel/CSV models not required or already available")

        # Probe for a table-extraction entry point; a failed import here is
        # expected when the installed release does not expose this name
        try:
            from unstructured.partition.pdf import partition_pdf_with_table_extraction
        except Exception:
            print("ℹ️ Table extraction models not required")

        print("\n✅ All UnstructuredIO models pre-downloaded successfully!")

    except Exception as e:
        print(f"⚠️ Some models may download on first use: {e}")


# Run the pre-download function
predownload_unstructured_models()

print("\n" + "=" * 70)

# Install and start Ollama
print("📦 Installing Ollama...")
try:
    result = subprocess.run(
        ["curl", "-fsSL", "https://ollama.com/install.sh"],
        capture_output=True, text=True, check=True
    )
    subprocess.run(["sh"], input=result.stdout, text=True, check=True)
    print("✅ Ollama installed successfully")
except subprocess.CalledProcessError as e:
    print(f"❌ Error installing Ollama: {e}")

# Start Ollama server
print("🔧 Starting Ollama server...")
os.environ['OLLAMA_HOST'] = '127.0.0.1:11434'

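try:
    # Discard server output: an unread PIPE buffer can fill up and stall the server
    ollama_process = subprocess.Popen(
        ["ollama", "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        start_new_session=True  # Same effect as preexec_fn=os.setsid
    )
    print(f"✅ Ollama server started (PID: {ollama_process.pid})")
    time.sleep(10)  # Give the server time to start listening
except Exception as e:
    print(f"❌ Error starting server: {e}")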

# Download models
print("📥 Downloading Ollama models...")
EMBEDDING_MODEL = "nomic-embed-text"
TEXT_MODEL = "gemma3:1b"

models = [EMBEDDING_MODEL, TEXT_MODEL]
for model in models:
    print(f"📥 Downloading {model}...")
    try:
        result = subprocess.run(["ollama", "pull", model],
                              capture_output=True, text=True, timeout=600)
        if result.returncode == 0:
            print(f"✅ {model} downloaded successfully")
        else:
            print(f"❌ Error downloading {model}: {result.stderr.strip()}")
    except Exception as e:
        print(f"❌ Error downloading {model}: {e}")

# Test connection
import ollama
try:
    response = ollama.chat(model=TEXT_MODEL, messages=[{'role': 'user', 'content': 'test'}])
    print("✅ Ollama connection working")
except Exception as e:
    print(f"❌ Connection error: {e}")

print("\n✅ SETUP COMPLETE - All models pre-downloaded!")
print("=" * 70)

🎯 Enhanced RAG-Based Expense Claims Processing System
🔍 Features: Filename-based IDs + Token Usage Tracking
📦 Installing dependencies...
✅ Dependencies installed

📥 Pre-downloading UnstructuredIO models...


yolox_l0.05.onnx:   0%|          | 0.00/217M [00:00<?, ?B/s]

✅ YOLO model (yolox_l0.05.onnx) downloaded successfully
ℹ️ Excel/CSV models not required or already available
ℹ️ Table extraction models not required

✅ All UnstructuredIO models pre-downloaded successfully!

📦 Installing Ollama...
✅ Ollama installed successfully
🔧 Starting Ollama server...
✅ Ollama server started (PID: 6803)
📥 Downloading Ollama models...
📥 Downloading nomic-embed-text...
✅ nomic-embed-text downloaded successfully
📥 Downloading gemma3:1b...
✅ gemma3:1b downloaded successfully
✅ Ollama connection working

✅ SETUP COMPLETE - All models pre-downloaded!
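
The setup cell above waits a fixed 10 seconds after launching `ollama serve`. One possible alternative is to poll the port until it accepts connections. The sketch below is not part of the original notebook: `wait_for_port` is a hypothetical helper, and the host and port are assumed to match the `OLLAMA_HOST` value set above. An open port only shows that something is listening, not that the models are loaded.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0, interval: float = 0.5) -> bool:
    """Return True once a TCP connection to host:port succeeds, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # create_connection raises OSError while nothing is listening yet
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            time.sleep(interval)
    return False
```

Under these assumptions, `wait_for_port("127.0.0.1", 11434)` could stand in for `time.sleep(10)`, with the return value deciding whether to continue with the model downloads.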


In [14]:
# ================================
# STEP 2: TOKEN USAGE TRACKING
# ================================

class TokenUsageTracker:
    """Track token usage across all LLM calls"""

    def __init__(self):
        self.call_history = []
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_tokens = 0
        self.call_count = 0

    def track_call(self, operation: str, filename: str, task: str, response) -> Dict[str, Any]:
        """Track a single LLM call and extract usage info"""

        usage_info = {
            "operation": operation,
            "filename": filename,
            "task": task,
            "timestamp": datetime.now().isoformat(),
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "duration_ms": 0
        }

        # Extract token usage from response
        try:
            if hasattr(response, 'usage_metadata') and response.usage_metadata:
                usage_info["input_tokens"] = response.usage_metadata.get('input_tokens', 0)
                usage_info["output_tokens"] = response.usage_metadata.get('output_tokens', 0)
                usage_info["total_tokens"] = response.usage_metadata.get('total_tokens', 0)

            # Fallback: try response_metadata
            elif hasattr(response, 'response_metadata') and response.response_metadata:
                metadata = response.response_metadata
                usage_info["input_tokens"] = metadata.get('prompt_eval_count', 0)
                usage_info["output_tokens"] = metadata.get('eval_count', 0)
                usage_info["total_tokens"] = usage_info["input_tokens"] + usage_info["output_tokens"]
                usage_info["duration_ms"] = metadata.get('total_duration', 0) // 1000000  # Convert to ms

        except Exception as e:
            print(f"⚠️ Could not extract token usage: {e}")

        # Update totals
        self.total_input_tokens += usage_info["input_tokens"]
        self.total_output_tokens += usage_info["output_tokens"]
        self.total_tokens += usage_info["total_tokens"]
        self.call_count += 1

        # Store call history
        self.call_history.append(usage_info)

        # Print usage info
        self.print_usage_info(usage_info)

        return usage_info


    def print_usage_info(self, usage_info: Dict[str, Any]):
        """Print formatted usage information"""
        print(f"📊 TOKEN USAGE - {usage_info['operation']} | {usage_info['filename']} | {usage_info['task']}")
        print(f"   📥 Input: {usage_info['input_tokens']} tokens")
        print(f"   📤 Output: {usage_info['output_tokens']} tokens")
        print(f"   🔢 Total: {usage_info['total_tokens']} tokens")
        if usage_info['duration_ms'] > 0:
            print(f"   ⏱️ Duration: {usage_info['duration_ms']}ms")
        print()

    def print_summary(self):
        """Print overall token usage summary"""
        print("=" * 60)
        print("📊 TOTAL TOKEN USAGE SUMMARY")
        print("=" * 60)
        print(f"🔢 Total LLM Calls: {self.call_count}")
        print(f"📥 Total Input Tokens: {self.total_input_tokens:,}")
        print(f"📤 Total Output Tokens: {self.total_output_tokens:,}")
        print(f"🎯 Grand Total Tokens: {self.total_tokens:,}")

        if self.call_count > 0:
            print(f"📊 Average per call: {self.total_tokens/self.call_count:.1f} tokens")
        print()

# Global token tracker
token_tracker = TokenUsageTracker()

print("\n✅ SETUP COMPLETE - TokenUsageTracker!")
print("=" * 70)


✅ SETUP COMPLETE - TokenUsageTracker!
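
The tracker reads token counts from `usage_metadata` first and falls back to the Ollama-style `prompt_eval_count` and `eval_count` fields in `response_metadata`. The sketch below isolates that lookup order so it can be tried without a running model. `extract_token_counts` is a hypothetical helper, and the response objects and their numbers are made-up stand-ins rather than real model output.

```python
from types import SimpleNamespace
from typing import Any, Dict


def extract_token_counts(response: Any) -> Dict[str, int]:
    """Mirror the tracker's lookup order: usage_metadata first, then response_metadata."""
    usage = getattr(response, "usage_metadata", None)
    if usage:
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        }
    meta = getattr(response, "response_metadata", None) or {}
    input_tokens = meta.get("prompt_eval_count", 0)
    output_tokens = meta.get("eval_count", 0)
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }


# Stand-in objects; a real call would pass the message returned by the LLM client
with_usage = SimpleNamespace(
    usage_metadata={"input_tokens": 120, "output_tokens": 30, "total_tokens": 150},
    response_metadata={},
)
ollama_style = SimpleNamespace(
    usage_metadata=None,
    response_metadata={"prompt_eval_count": 80, "eval_count": 20},
)

print(extract_token_counts(with_usage))
print(extract_token_counts(ollama_style))
```

A response carrying neither field yields zeros, which is also how the tracker's totals behave when no usage data is available.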


In [15]:
# ================================
# STEP 3: FILENAME-BASED DOCUMENT MANAGEMENT
# ================================

@dataclass
class ClaimDocument:
    """Document with filename-based identification"""
    filename: str  # Primary identifier (no more UUIDs!)
    file_path: str
    raw_text: str
    chunks: List[str]
    metadata: Dict[str, Any]
    processed_timestamp: datetime

class FilenameBasedDocumentManager:
    """Manages documents using filenames as primary identifiers"""

    def __init__(self):
        self.documents_registry = {}  # filename -> ClaimDocument
        self.chunk_to_file_map = {}  # chunk_id -> filename

    def register_document(self, file_path: str, raw_text: str) -> str:
        """Register document using filename as ID"""

        # Filename without extension; files sharing a stem (a.pdf, a.png) map to the same ID
        filename = Path(file_path).stem

        print(f"📋 REGISTERING DOCUMENT: {filename}")
        print(f"   📁 Source: {Path(file_path).name}")
        print(f"   📄 Text length: {len(raw_text)} characters")

        # Create isolated chunks for this document
        chunks = self.create_document_chunks(raw_text, filename)

        claim_doc = ClaimDocument(
            filename=filename,
            file_path=file_path,
            raw_text=raw_text,
            chunks=chunks,
            metadata={
                "file_name": Path(file_path).name,
                "file_extension": Path(file_path).suffix,
                "chunk_count": len(chunks),
                "source": "ocr_extraction"
            },
            processed_timestamp=datetime.now()
        )

        self.documents_registry[filename] = claim_doc

        # Update chunk mapping
        for i, chunk in enumerate(chunks):
            chunk_id = f"{filename}_chunk_{i}"
            self.chunk_to_file_map[chunk_id] = filename

        print(f"✅ Document registered: {filename} with {len(chunks)} chunks")
        return filename

    def create_document_chunks(self, text: str, filename: str) -> List[str]:
        """Create chunks with filename-specific context isolation"""

        print(f"🔪 CHUNKING DOCUMENT: {filename}")

        lines = text.split('\n')
        chunks = []
        current_chunk = []
        current_length = 0
        max_chunk_size = 500

        # Expense document section markers
        section_markers = [
            'total', 'amount', 'date', 'vendor', 'receipt', 'invoice',
            'item', 'quantity', 'price', 'tax', 'subtotal'
        ]

        for line in lines:
            line = line.strip()
            if not line:
                continue

            line_length = len(line)
            is_section_start = any(marker in line.lower() for marker in section_markers)

            if (current_length + line_length > max_chunk_size) or \
               (is_section_start and current_chunk and current_length > 200):

                chunk_text = '\n'.join(current_chunk)
                if chunk_text.strip():
                    # Add filename isolation metadata to chunk
                    isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                    chunks.append(isolated_chunk)

                current_chunk = [line]
                current_length = line_length
            else:
                current_chunk.append(line)
                current_length += line_length + 1

        # Add final chunk
        if current_chunk:
            chunk_text = '\n'.join(current_chunk)
            if chunk_text.strip():
                isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                chunks.append(isolated_chunk)

        print(f"   🔪 Created {len(chunks)} chunks (avg {len(text)//len(chunks) if chunks else 0} chars each)")
        return chunks



    def get_document_context(self, filename: str) -> Optional[ClaimDocument]:
        """Get complete context for a specific document"""
        return self.documents_registry.get(filename)

    def list_all_documents(self) -> List[str]:
        """List all registered filenames"""
        return list(self.documents_registry.keys())

print("\n✅ SETUP COMPLETE - FILENAME-BASED DOCUMENT MANAGEMENT!")
print("=" * 70)


✅ SETUP COMPLETE - FILENAME-BASED DOCUMENT MANAGEMENT!
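
The chunker above starts a new chunk either when the size limit would be exceeded or when a line containing a section marker arrives after enough text has accumulated. The following is a simplified, standalone sketch of that rule, not the class method itself. `chunk_lines`, its parameter names, and the shortened marker list are assumptions for illustration, and the `[DOCUMENT: ...]` prefix is left out.

```python
from typing import List, Sequence

SECTION_MARKERS = ("total", "amount", "date", "vendor", "receipt", "invoice")


def chunk_lines(text: str, max_chunk_size: int = 500, min_split_length: int = 200,
                markers: Sequence[str] = SECTION_MARKERS) -> List[str]:
    """Group non-empty lines into chunks, splitting on size or on section markers."""
    chunks: List[str] = []
    current: List[str] = []
    current_length = 0
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        starts_section = any(marker in line.lower() for marker in markers)
        too_long = current_length + len(line) > max_chunk_size
        # Flush when the limit would be exceeded, or at a marker line once
        # the current chunk is longer than min_split_length
        if current and (too_long or (starts_section and current_length > min_split_length)):
            chunks.append("\n".join(current))
            current, current_length = [], 0
        current.append(line)
        current_length += len(line) + 1  # +1 accounts for the joining newline
    if current:
        chunks.append("\n".join(current))
    return chunks


sample = "ACME Store\n123 Main St\nCoffee 4.50\nBagel 3.00\nTotal 7.50"
for chunk in chunk_lines(sample, max_chunk_size=30, min_split_length=10):
    print(repr(chunk))
```

The small limits in the example are only there to force splits on a five-line receipt; the notebook itself uses 500 and 200 characters.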


In [16]:
# ================================
# STEP 4: ENHANCED OCR PROCESSOR
# ================================

class EnhancedOCRProcessor:
    """OCR processing with detailed progress tracking"""

    def __init__(self):
        self.supported_formats = ['.pdf', '.jpg', '.jpeg', '.png', '.tiff']

    def extract_text_from_document(self, file_path: str) -> str:
        """Extract text with detailed progress tracking"""

        filename = Path(file_path).name
        print(f"🔍 EXTRACTING TEXT FROM: {filename}")
        print(f"   📁 Full path: {file_path}")
        print(f"   📊 File size: {Path(file_path).stat().st_size / 1024:.1f} KB")

        try:
            from unstructured.partition.auto import partition

            print(f"   🔄 Processing with UnstructuredIO...")

            # Process document with UnstructuredIO
            elements = partition(filename=file_path)

            print(f"   📋 Found {len(elements)} document elements")

            # Extract text from all elements
            full_text = ""
            for i, element in enumerate(elements):
                if hasattr(element, 'text') and element.text:
                    full_text += element.text + "\n"
                    if i < 5:  # Show first few elements
                        print(f"     Element {i+1}: {element.text[:50]}...")

            # Clean and normalize text
            full_text = self.clean_extracted_text(full_text)

            print(f"   ✅ Extracted {len(full_text)} characters")
            print(f"   📝 Text preview: {full_text[:100]}...")
            return full_text

        except Exception as e:
            print(f"   ❌ OCR extraction failed: {e}")
            return ""

    def clean_extracted_text(self, text: str) -> str:
        """Clean extracted text with progress info"""
        if not text:
            return ""

        original_length = len(text)
        lines = text.split('\n')
        cleaned_lines = []

        for line in lines:
            line = line.strip()
            if line and len(line) > 2:
                cleaned_lines.append(line)

        cleaned_text = '\n'.join(cleaned_lines)
        print(f"   🧹 Cleaned: {original_length} → {len(cleaned_text)} chars ({len(cleaned_lines)} lines)")

        return cleaned_text

print("\n✅ SETUP COMPLETE - ENHANCED OCR PROCESSOR!")
print("=" * 70)


✅ SETUP COMPLETE - ENHANCED OCR PROCESSOR!
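
`clean_extracted_text` keeps only lines longer than two characters. On receipts this may also drop short values that OCR places on their own line, such as a bare amount. The sketch below reproduces the filter as a standalone function to make that visible. `clean_lines` and its `min_length` parameter are hypothetical, and the sample text is invented.

```python
def clean_lines(text: str, min_length: int = 3) -> str:
    """Keep stripped lines with at least min_length characters (the cell above keeps len > 2)."""
    stripped = [line.strip() for line in text.split("\n")]
    return "\n".join(line for line in stripped if len(line) >= min_length)


ocr_text = "Coffee\n  \n$5\n12\nTotal 17.00"
print(clean_lines(ocr_text))                # default threshold drops "$5" and "12"
print(clean_lines(ocr_text, min_length=1))  # only empty lines are dropped
```

Whether a lower threshold is worthwhile depends on how the OCR engine lays out amounts; it is offered here as something to check, not as a recommended setting.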


In [17]:
# ================================
# STEP 5: ENHANCED VECTOR STORE
# ================================

class EnhancedIsolatedVectorStore:
    """ChromaDB with enhanced tracking and filename-based isolation"""

    def __init__(self, embedding_model: str = EMBEDDING_MODEL):
        import chromadb

        self.embedding_model = embedding_model

        print(f"🗄️ INITIALIZING VECTOR STORE")
        print(f"   🤖 Embedding Model: {embedding_model}")

        # Initialize ChromaDB client using new API
        self.client = chromadb.PersistentClient(path="./chroma_db")

        # Create collection
        self.collection = self.client.get_or_create_collection(
            name="filename_based_expense_claims",
            metadata={"hnsw:space": "cosine"}
        )

        print(f"   ✅ ChromaDB initialized")

    def embed_text(self, text: str, filename: str = "unknown") -> List[float]:
        """Generate embeddings with progress tracking"""

        print(f"🔢 GENERATING EMBEDDING: {filename}")
        print(f"   📝 Text length: {len(text)} chars")

        try:
            response = ollama.embeddings(model=self.embedding_model, prompt=text)
            embedding = response['embedding']
            print(f"   ✅ Generated {len(embedding)}-dimensional embedding")
            return embedding
        except Exception as e:
            print(f"   ❌ Embedding error: {e}")
            return []

    def add_document_chunks(self, filename: str, chunks: List[str], metadata: Dict[str, Any]):
        """Add chunks for a specific document with detailed tracking"""

        print(f"📚 ADDING CHUNKS TO VECTOR STORE: {filename}")
        print(f"   📊 Number of chunks: {len(chunks)}")

        embeddings = []
        documents = []
        chunk_ids = []
        metadatas = []

        for i, chunk in enumerate(chunks):
            print(f"   🔄 Processing chunk {i+1}/{len(chunks)}")

            # Generate embedding
            embedding = self.embed_text(chunk, f"{filename}_chunk_{i}")
            if not embedding:
                print(f"   ⚠️ Skipping chunk {i+1} - no embedding generated")
                continue

            chunk_id = f"{filename}_chunk_{i}"
            chunk_metadata = {
                **metadata,
                "filename": filename,
                "chunk_index": i,
                "chunk_id": chunk_id,
                "isolated": True
            }

            embeddings.append(embedding)
            documents.append(chunk)  # Keep texts aligned with the embeddings that succeeded
            chunk_ids.append(chunk_id)
            metadatas.append(chunk_metadata)

        # Add to ChromaDB (all four lists must have the same length)
        if embeddings:
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids
            )

            print(f"   ✅ Added {len(embeddings)} chunks to vector store")
        else:
            print(f"   ❌ No chunks added - all embeddings failed")

    def query_document_specific(self, query: str, filename: str, n_results: int = 3) -> Dict[str, Any]:
        """Query specific document only - prevents cross-contamination"""

        print(f"🔍 QUERYING VECTOR STORE: {filename}")
        print(f"   ❓ Query: {query}")
        print(f"   📊 Requesting {n_results} results")

        query_embedding = self.embed_text(query, f"query_{filename}")
        if not query_embedding:
            return {"error": "Failed to generate query embedding"}

        # Query with filename filter to ensure isolation
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where={"filename": filename},  # CRITICAL: Isolates to specific document
            include=["documents", "metadatas", "distances"]
        )

        print(f"   ✅ Found {len(results['documents'][0]) if results['documents'] else 0} relevant chunks")

        return {
            "documents": results['documents'][0] if results['documents'] else [],
            "metadatas": results['metadatas'][0] if results['metadatas'] else [],
            "distances": results['distances'][0] if results['distances'] else [],
            "filename": filename
        }

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about stored documents"""

        print("📊 GENERATING COLLECTION STATISTICS")

        count = self.collection.count()

        # Get unique filenames
        all_metadata = self.collection.get(include=["metadatas"])
        filenames = set()
        if all_metadata['metadatas']:
            for meta in all_metadata['metadatas']:
                if 'filename' in meta:
                    filenames.add(meta['filename'])

        stats = {
            "total_chunks": count,
            "unique_documents": len(filenames),
            "filenames": list(filenames)
        }

        print(f"   📚 Total chunks: {stats['total_chunks']}")
        print(f"   📄 Unique documents: {stats['unique_documents']}")
        print(f"   📝 Documents: {', '.join(stats['filenames'])}")

        return stats

print("\n✅ SETUP COMPLETE - ENHANCED VECTOR STORE!")
print("=" * 70)


✅ SETUP COMPLETE - ENHANCED VECTOR STORE!
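
`query_document_specific` relies on two ideas: restrict candidates to one filename, then rank them by distance. With `hnsw:space` set to `cosine`, the returned distances are expected to be cosine distances (1 minus cosine similarity), so smaller means closer. The pure-Python sketch below illustrates both ideas on toy vectors. It is not how ChromaDB searches internally (that uses an approximate HNSW index), and `cosine_distance`, `query_one_document`, and the sample records are hypothetical. Vectors are assumed to be non-zero.

```python
import math
from typing import List, Sequence, Tuple

Record = Tuple[str, str, Sequence[float]]  # (chunk_id, filename, vector)


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 for the same direction, 1 for orthogonal vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def query_one_document(query_vec: Sequence[float], records: List[Record],
                       filename: str, n_results: int = 3) -> List[Tuple[str, float]]:
    """Filter to one filename first, then return the closest chunks."""
    candidates = [(chunk_id, cosine_distance(query_vec, vec))
                  for chunk_id, fname, vec in records if fname == filename]
    return sorted(candidates, key=lambda pair: pair[1])[:n_results]


records: List[Record] = [
    ("receipt_a_chunk_0", "receipt_a", [1.0, 0.0]),
    ("receipt_a_chunk_1", "receipt_a", [0.0, 1.0]),
    ("receipt_b_chunk_0", "receipt_b", [1.0, 0.0]),
]
for chunk_id, distance in query_one_document([1.0, 0.1], records, "receipt_a"):
    print(f"{chunk_id}: {distance:.3f}")
```

Note that `receipt_b_chunk_0` never appears in the result even though its vector matches the query closely, which is the isolation the `where={"filename": filename}` filter is meant to provide.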


In [18]:
# ================================
# STEP 6: ENHANCED EXPENSE TASK MANAGER
# ================================

class EnhancedExpenseTaskManager:
    """Manages predefined expense extraction tasks with better tracking"""

    def __init__(self):
        self.predefined_tasks = {
            "extract_amount": {
                "query": "total amount due payment cost price sum money dollar",
                "description": "Extract the total amount from this expense document",
                "expected_format": "numeric value with currency"
            },
            "extract_date": {
                "query": "date transaction purchase invoice receipt timestamp when",
                "description": "Extract the date from this expense document",
                "expected_format": "date in YYYY-MM-DD format"
            },
            "extract_vendor": {
                "query": "vendor merchant company business supplier store restaurant hotel",
                "description": "Extract vendor/merchant name from this expense document",
                "expected_format": "company or business name"
            },
            "extract_category": {
                "query": "category type classification expense kind service product item",
                "description": "Determine expense category from this document",
                "expected_format": "expense category classification"
            },
            "extract_items": {
                "query": "items products services line items purchases description details",
                "description": "Extract itemized details from this expense document",
                "expected_format": "list of items or services"
            },
            "extract_tax": {
                "query": "tax VAT GST sales tax tax rate percentage",
                "description": "Extract tax information from this expense document",
                "expected_format": "tax amount and rate"
            }
        }

    def get_task_info(self, task_name: str) -> Dict[str, str]:
        """Get complete task information"""
        return self.predefined_tasks.get(task_name, {})

    def list_available_tasks(self) -> List[str]:
        """List all available extraction tasks"""
        return list(self.predefined_tasks.keys())

print("\n✅ SETUP COMPLETE - ENHANCED EXPENSE TASK MANAGER!")
print("=" * 70)


✅ SETUP COMPLETE - ENHANCED EXPENSE TASK MANAGER!


In [19]:
# ================================
# STEP 7: ENHANCED RAG PROCESSOR
# ================================

class EnhancedRAGExpenseProcessor:
    """RAG-based expense processor with comprehensive tracking"""

    def __init__(self, text_model: str = TEXT_MODEL):
        from langchain_ollama import ChatOllama

        print(f"🚀 INITIALIZING RAG EXPENSE PROCESSOR")
        print(f"   🤖 Text Model: {text_model}")

        self.llm = ChatOllama(
            model=text_model,
            temperature=0.1,
            base_url="http://127.0.0.1:11434"
        )

        self.vector_store = EnhancedIsolatedVectorStore()
        self.task_manager = EnhancedExpenseTaskManager()
        self.document_manager = FilenameBasedDocumentManager()
        self.ocr_processor = EnhancedOCRProcessor()

        print("   ✅ All components initialized")

    def ingest_document(self, file_path: str) -> Optional[str]:
        """INGESTION PHASE: Process document and store in vector DB"""

        filename = Path(file_path).name
        print("\n" + "="*70)
        print(f"🔄 INGESTION PHASE STARTING")
        print(f"📄 FILE: {filename}")
        print(f"📁 PATH: {file_path}")
        print("="*70)

        # Step 1: OCR extraction
        raw_text = self.ocr_processor.extract_text_from_document(file_path)
        if not raw_text:
            print("❌ INGESTION FAILED: No text extracted")
            return None

        # Step 2: Register document with filename-based system
        filename_id = self.document_manager.register_document(file_path, raw_text)

        # Step 3: Get document context
        document = self.document_manager.get_document_context(filename_id)

        # Step 4: Store in vector database
        metadata = {
            **document.metadata,
            "ingestion_timestamp": datetime.now().isoformat()
        }

        self.vector_store.add_document_chunks(
            filename=filename_id,
            chunks=document.chunks,
            metadata=metadata
        )

        print(f"✅ INGESTION COMPLETED: {filename_id}")
        print("="*70)
        return filename_id

    def process_expense_task(self, filename: str, task_name: str) -> Dict[str, Any]:
        """INFERENCE PHASE: Process specific task for document"""

        print(f"\n🎯 INFERENCE PHASE STARTING")
        print(f"📄 DOCUMENT: {filename}")
        print(f"🎯 TASK: {task_name}")
        print("-" * 50)

        # Step 1: Get task information
        task_info = self.task_manager.get_task_info(task_name)
        if not task_info:
            return {"error": f"Unknown task: {task_name}"}

        task_query = task_info.get("query", "")
        task_description = task_info.get("description", "")

        print(f"📋 Task Description: {task_description}")
        print(f"🔍 Search Query: {task_query}")

        # Step 2: Retrieve relevant chunks (ISOLATED to this document)
        retrieval_results = self.vector_store.query_document_specific(
            query=task_query,
            filename=filename,
            n_results=3
        )

        if retrieval_results.get("error"):
            return retrieval_results

        # Step 3: Prepare optimized context
        context = self.optimize_context(retrieval_results, task_name)

        # Step 4: Generate response with LLM (WITH TOKEN TRACKING)
        response_text, token_usage = self.generate_task_response_with_tracking(
            context, task_name, task_description, filename
        )

        result = {
            "task": task_name,
            "filename": filename,
            "response": response_text,
            "context_chunks_used": len(retrieval_results["documents"]),
            # "confidence": self.calculate_confidence(retrieval_results),
            "token_usage": token_usage
        }

        print(f"✅ INFERENCE COMPLETED: {task_name} for {filename}")
        print("-" * 50)

        return result

    def optimize_context(self, retrieval_results: Dict[str, Any], task_name: str) -> str:
        """CONTEXT OPTIMIZATION: Reduce context overloading"""

        documents = retrieval_results.get("documents", [])
        distances = retrieval_results.get("distances", [])
        filename = retrieval_results.get("filename", "unknown")

        print(f"🔧 OPTIMIZING CONTEXT: {filename}")
        print(f"   📊 Raw chunks: {len(documents)}")

        if not documents:
            return "No relevant context found"

        # Rank documents by relevance (smaller distance = more relevant)
        doc_scores = list(zip(documents, distances))
        doc_scores.sort(key=lambda x: x[1])

        optimized_chunks = []
        total_length = 0
        max_context_length = 1500

        for i, (doc, score) in enumerate(doc_scores):
            # Remove document prefix from chunks
            clean_doc = doc.replace(f"[DOCUMENT: {filename}]\n", "")

            if total_length + len(clean_doc) <= max_context_length:
                optimized_chunks.append(clean_doc)
                total_length += len(clean_doc)
                print(f"   ✅ Chunk {i+1}: {len(clean_doc)} chars (distance: {score:.3f})")
            else:
                remaining_space = max_context_length - total_length
                if remaining_space > 100:
                    truncated = clean_doc[:remaining_space] + "..."
                    optimized_chunks.append(truncated)
                    print(f"   ✂️ Chunk {i+1}: truncated to {len(truncated)} chars")
                break

        context = "\n\n---\n\n".join(optimized_chunks)
        print(f"   🎯 Final context: {len(context)} chars from {len(optimized_chunks)} chunks")

        return context

    def generate_task_response_with_tracking(self, context: str, task_name: str, task_description: str, filename: str) -> Tuple[str, Dict[str, Any]]:
        """Generate LLM response with token usage tracking"""

        print(f"🤖 GENERATING LLM RESPONSE: {task_name} | {filename}")

        prompt = f"""You are an expert expense analyst. {task_description}

CONTEXT FROM EXPENSE DOCUMENT ({filename}):
{context}

TASK: {task_name}
INSTRUCTION: {task_description}

Based ONLY on the context provided above, extract the requested information. Be precise and factual. If the information is not clearly present in the context, state "Information not found in provided context."

Response:"""

        print(f"   📝 Prompt length: {len(prompt)} characters")

        try:
            response = self.llm.invoke(prompt)

            # Track token usage
            token_usage = token_tracker.track_call("llm_inference", filename, task_name, response)

            return response.content.strip(), token_usage

        except Exception as e:
            error_msg = f"Error generating response: {e}"
            print(f"   ❌ {error_msg}")
            return error_msg, {}

    # def calculate_confidence(self, retrieval_results: Dict[str, Any]) -> float:
    #     """Calculate confidence based on retrieval quality"""
    #     distances = retrieval_results.get("distances", [])
    #     if not distances:
    #         return 0.0

    #     avg_distance = sum(distances) / len(distances)
    #     confidence = max(0.0, 1.0 - avg_distance)
    #     return round(confidence, 3)

    def process_all_tasks_for_document(self, filename: str) -> Dict[str, Any]:
        """Process all predefined tasks for a document"""

        print(f"\n📊 PROCESSING ALL TASKS FOR: {filename}")
        print("="*50)

        tasks = self.task_manager.list_available_tasks()
        results = {}

        for i, task in enumerate(tasks, 1):
            print(f"\n[{i}/{len(tasks)}] Starting task: {task}")
            result = self.process_expense_task(filename, task)
            results[task] = result

        print(f"\n✅ ALL TASKS COMPLETED FOR: {filename}")
        return results

print("\n✅ SETUP COMPLETE - ENHANCED RAG PROCESSOR!")
print("=" * 70)


✅ SETUP COMPLETE - ENHANCED RAG PROCESSOR!
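
The `token_tracker.track_call(...)` call in `generate_task_response_with_tracking` receives the raw LLM response, but the tracker itself is defined in an earlier cell. The sketch below shows one plausible way to read token counts from such a response. It assumes the response may expose a `usage_metadata` mapping with `input_tokens`, `output_tokens`, and `total_tokens` (as LangChain chat messages can); the helper name `extract_token_usage` is hypothetical and is not the notebook's tracker.

```python
from types import SimpleNamespace
from typing import Any, Dict


def extract_token_usage(response: Any) -> Dict[str, int]:
    """Hypothetical helper: read token counts from a response object.

    Assumes the response may carry a `usage_metadata` mapping;
    returns zeros when that metadata is missing.
    """
    usage = getattr(response, "usage_metadata", None) or {}
    input_tokens = int(usage.get("input_tokens", 0))
    output_tokens = int(usage.get("output_tokens", 0))
    # Fall back to the sum when the backend does not report a total
    total_tokens = int(usage.get("total_tokens", input_tokens + output_tokens))
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }


# Stand-in object for illustration only; a real call would pass the LLM response.
fake_response = SimpleNamespace(usage_metadata={"input_tokens": 120, "output_tokens": 30})
usage = extract_token_usage(fake_response)
```

Returning zeros rather than raising keeps the per-task result shape stable, so the `token_usage.get(...)` lookups used later in the workflow still work when a backend reports no usage.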


In [20]:
# ================================
# STEP 8: ENHANCED WORKFLOW
# ================================

from langgraph.graph import StateGraph
from typing import TypedDict

class EnhancedRAGWorkflowState(TypedDict):
    """Enhanced state for RAG workflow"""
    file_paths: List[str]
    current_file_index: int
    processed_filenames: List[str]
    current_filename: str
    task_results: Dict[str, Dict[str, Any]]
    workflow_status: str
    error: Optional[str]

def create_enhanced_rag_workflow() -> Any:
    """Create and compile the enhanced LangGraph workflow (returns the compiled graph, not a StateGraph)"""

    processor = EnhancedRAGExpenseProcessor()

    def enhanced_ingestion_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced ingestion with detailed tracking"""

        print("\n" + "🔄 WORKFLOW: ENHANCED INGESTION PHASE STARTING")
        print("="*70)

        file_paths = state.get("file_paths", [])
        processed_filenames = []

        for i, file_path in enumerate(file_paths, 1):
            print(f"\n[{i}/{len(file_paths)}] Processing file: {Path(file_path).name}")

            try:
                filename = processor.ingest_document(file_path)
                if filename:
                    processed_filenames.append(filename)
                    print(f"✅ Successfully ingested: {filename}")
                else:
                    print(f"❌ Failed to ingest: {Path(file_path).name}")

            except Exception as e:
                print(f"❌ Error ingesting {Path(file_path).name}: {e}")

        state["processed_filenames"] = processed_filenames
        state["workflow_status"] = "ingestion_complete" if processed_filenames else "ingestion_failed"

        print(f"\n📊 INGESTION PHASE COMPLETED")
        print(f"   ✅ Successfully processed: {len(processed_filenames)} files")
        print(f"   ❌ Failed: {len(file_paths) - len(processed_filenames)} files")

        return state

    def enhanced_task_processing_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced task processing with detailed tracking"""

        print("\n🎯 WORKFLOW: ENHANCED TASK PROCESSING PHASE STARTING")
        print("="*70)

        processed_filenames = state.get("processed_filenames", [])
        task_results = {}

        for i, filename in enumerate(processed_filenames, 1):
            print(f"\n[{i}/{len(processed_filenames)}] Processing tasks for: {filename}")

            try:
                results = processor.process_all_tasks_for_document(filename)
                task_results[filename] = results
                print(f"✅ Completed all tasks for: {filename}")

            except Exception as e:
                print(f"❌ Error processing tasks for {filename}: {e}")
                task_results[filename] = {"error": str(e)}

        state["task_results"] = task_results
        state["workflow_status"] = "processing_complete"

        print(f"\n📊 TASK PROCESSING PHASE COMPLETED")
        print(f"   📄 Documents processed: {len(task_results)}")

        return state

    def enhanced_results_compilation_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced results compilation with detailed stats and CSV export"""

        print("\n📊 WORKFLOW: ENHANCED RESULTS COMPILATION STARTING")
        print("="*70)

        task_results = state.get("task_results", {})

        # Compile detailed statistics
        total_documents = len(task_results)
        successful_documents = sum(1 for r in task_results.values() if "error" not in r)

        # Save results with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"enhanced_rag_expense_results_{timestamp}.json"
        csv_file = f"enhanced_rag_expense_results_{timestamp}.csv"

        # Create comprehensive results package
        comprehensive_results = {
            "timestamp": timestamp,
            "summary": {
                "total_documents": total_documents,
                "successful_documents": successful_documents,
                "failed_documents": total_documents - successful_documents
            },
            "token_usage_summary": {
                "total_calls": token_tracker.call_count,
                "total_input_tokens": token_tracker.total_input_tokens,
                "total_output_tokens": token_tracker.total_output_tokens,
                "total_tokens": token_tracker.total_tokens
            },
            "document_results": task_results,
            "token_call_history": token_tracker.call_history
        }

        # Save JSON results
        with open(results_file, 'w') as f:
            json.dump(comprehensive_results, f, indent=2, default=str)

        print(f"💾 JSON RESULTS SAVED TO: {results_file}")

        # Create CSV from results
        csv_rows = []

        for filename, doc_results in task_results.items():
            if "error" in doc_results:
                # Add error row
                csv_rows.append({
                    "filename": filename,
                    "task": "error",
                    "response": doc_results["error"],
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
            else:
                # Process each task for this document
                for task_name, task_result in doc_results.items():
                    if isinstance(task_result, dict):
                        token_usage = task_result.get("token_usage", {})
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": task_result.get("response", ""),
                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                            "input_tokens": token_usage.get("input_tokens", 0),
                            "output_tokens": token_usage.get("output_tokens", 0),
                            "total_tokens": token_usage.get("total_tokens", 0)
                        })

        # Save CSV
        if csv_rows:
            df = pd.DataFrame(csv_rows)

            # Reorder columns for better readability
            column_order = [
                "filename", "task", "response",
                "context_chunks_used", "input_tokens",
                "output_tokens", "total_tokens"
            ]
            df = df[column_order]

            # Save to CSV
            df.to_csv(csv_file, index=False, encoding='utf-8')
            print(f"💾 CSV RESULTS SAVED TO: {csv_file}")

            # Display summary statistics
            print(f"\n📊 CSV Summary:")
            print(f"   📄 Total rows: {len(df)}")
            print(f"   📁 Documents: {df['filename'].nunique()}")
            print(f"   🎯 Tasks per document: {df.groupby('filename').size().mean():.1f}")
            print(f"   🔢 Total tokens used: {df['total_tokens'].sum():,}")

        # Also save a summary CSV with aggregated data per document
        summary_csv_file = f"enhanced_rag_expense_summary_{timestamp}.csv"
        summary_rows = []

        for filename, doc_results in task_results.items():
            if "error" not in doc_results:
                row = {"filename": filename}

                # Extract key information from each task
                for task_name in ["extract_amount", "extract_date", "extract_vendor",
                                "extract_category", "extract_tax"]:
                    if task_name in doc_results:
                        response = doc_results[task_name].get("response", "")
                        # Clean the response (first line only, truncated to 100 chars)
                        cleaned = response.split('\n')[0][:100] if response else ""
                        row[task_name] = cleaned

                # Add token totals
                total_tokens = sum(
                    doc_results.get(task, {}).get("token_usage", {}).get("total_tokens", 0)
                    for task in doc_results if isinstance(doc_results.get(task), dict)
                )
                row["total_tokens_used"] = total_tokens

                summary_rows.append(row)

        if summary_rows:
            summary_df = pd.DataFrame(summary_rows)
            summary_df.to_csv(summary_csv_file, index=False, encoding='utf-8')
            print(f"💾 SUMMARY CSV SAVED TO: {summary_csv_file}")

        print(f"\n📊 FILES SAVED:")
        print(f"   📄 Detailed JSON: {results_file}")
        print(f"   📄 Detailed CSV: {csv_file}")
        print(f"   📄 Summary CSV: {summary_csv_file}")
        print(f"   📄 Documents processed: {successful_documents}/{total_documents}")

        # Print token usage summary
        token_tracker.print_summary()

        state["workflow_status"] = "complete"
        return state

    # Build enhanced workflow
    workflow = StateGraph(EnhancedRAGWorkflowState)

    workflow.add_node("enhanced_ingestion", enhanced_ingestion_node)
    workflow.add_node("enhanced_task_processing", enhanced_task_processing_node)
    workflow.add_node("enhanced_results_compilation", enhanced_results_compilation_node)

    workflow.set_entry_point("enhanced_ingestion")
    workflow.add_edge("enhanced_ingestion", "enhanced_task_processing")
    workflow.add_edge("enhanced_task_processing", "enhanced_results_compilation")
    workflow.set_finish_point("enhanced_results_compilation")

    return workflow.compile()

print("\n✅ SETUP COMPLETE - ENHANCED WORKFLOW!")
print("=" * 70)


✅ SETUP COMPLETE - ENHANCED WORKFLOW!
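
The results-compilation node flattens the nested `task_results` mapping (filename, then task, then result) into one row per task before writing the CSV. The sketch below isolates that step using only the standard library, so it can be checked without pandas or an LLM. The function name `flatten_task_results` and the sample data are hypothetical; the column names follow the node above.

```python
import csv
import io
from typing import Any, Dict, List

COLUMNS = [
    "filename", "task", "response", "context_chunks_used",
    "input_tokens", "output_tokens", "total_tokens",
]


def flatten_task_results(task_results: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn nested per-document results into flat rows."""
    rows: List[Dict[str, Any]] = []
    for filename, doc_results in task_results.items():
        if "error" in doc_results:
            # Document-level failure: emit a single error row with zeroed counters
            rows.append({
                "filename": filename, "task": "error",
                "response": doc_results["error"], "context_chunks_used": 0,
                "input_tokens": 0, "output_tokens": 0, "total_tokens": 0,
            })
            continue
        for task_name, task_result in doc_results.items():
            if not isinstance(task_result, dict):
                continue
            usage = task_result.get("token_usage", {})
            rows.append({
                "filename": filename, "task": task_name,
                "response": task_result.get("response", ""),
                "context_chunks_used": task_result.get("context_chunks_used", 0),
                "input_tokens": usage.get("input_tokens", 0),
                "output_tokens": usage.get("output_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            })
    return rows


# Made-up sample data for illustration only
sample = {
    "bill_a.pdf": {
        "extract_amount": {
            "response": "Total: 120.50",
            "context_chunks_used": 3,
            "token_usage": {"input_tokens": 200, "output_tokens": 12, "total_tokens": 212},
        }
    },
    "bill_b.pdf": {"error": "ingestion failed"},
}
rows = flatten_task_results(sample)

# Write the rows to an in-memory CSV instead of a file
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=COLUMNS)
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
```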


In [21]:
# ================================
# STEP 9: DEMONSTRATION WITH ENHANCED TRACKING
# ================================

print("\n" + "="*70)
print("🚀 ENHANCED RAG DEMONSTRATION STARTING")
print("📊 Features: Filename IDs + Token Tracking + Detailed Progress")
print("="*70)

# Initialize enhanced system
processor = EnhancedRAGExpenseProcessor()

# Check for sample documents
samples_dir = Path("/kaggle/input/hotel-bills")
if samples_dir.exists():
    sample_files = []
    for ext in ['.pdf', '.jpg', '.jpeg', '.png']:
        sample_files.extend(list(samples_dir.rglob(f"*{ext}")))

    # Use first few files for demo
    # demo_files = sample_files[:min(3, len(sample_files))]
    demo_files = sample_files

    if demo_files:
        print(f"📄 FOUND {len(demo_files)} SAMPLE DOCUMENTS:")
        for i, file in enumerate(demo_files, 1):
            print(f"   {i}. {file.name} ({file.stat().st_size/1024:.1f} KB)")

        # Create enhanced workflow
        workflow = create_enhanced_rag_workflow()

        # Execute enhanced workflow
        initial_state = {
            "file_paths": [str(f) for f in demo_files],
            "current_file_index": 0,
            "processed_filenames": [],
            "current_filename": "",
            "task_results": {},
            "workflow_status": "initialized",
            "error": None
        }

        print("\n🔄 EXECUTING ENHANCED RAG WORKFLOW...")
        final_state = workflow.invoke(initial_state)

        # Display enhanced results
        print("\n📊 ENHANCED FINAL RESULTS:")
        print("="*70)

        task_results = final_state.get("task_results", {})

        for filename, results in task_results.items():
            print(f"\n📄 DOCUMENT: {filename}")
            print("─" * 50)

            if "error" in results:
                print(f"❌ Error: {results['error']}")
                continue

            for task_name, task_result in results.items():
                if isinstance(task_result, dict):
                    response = task_result.get("response", "No response")
                    # confidence = task_result.get("confidence", 0)
                    chunks_used = task_result.get("context_chunks_used", 0)
                    token_usage = task_result.get("token_usage", {})

                    print(f"\n🎯 {task_name.upper()}:")
                    print(f"   📝 Response: {response[:150]}...")
                    # print(f"   🎯 Confidence: {confidence:.3f}")
                    print(f"   📚 Chunks used: {chunks_used}")
                    if token_usage:
                        print(f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}")

        # Display vector store statistics
        print("\n📊 ENHANCED VECTOR STORE STATISTICS:")
        print("="*50)
        stats = processor.vector_store.get_collection_stats()

        # Final token usage summary
        token_tracker.print_summary()

    else:
        print("📂 No sample documents found for demonstration")

else:
    print("📂 No input directory found")


🚀 ENHANCED RAG DEMONSTRATION STARTING
📊 Features: Filename IDs + Token Tracking + Detailed Progress
🚀 INITIALIZING RAG EXPENSE PROCESSOR
   🤖 Text Model: gemma3:1b
🗄️ INITIALIZING VECTOR STORE
   🤖 Embedding Model: nomic-embed-text


KeyboardInterrupt: 

In [None]:
# ================================
# STEP 10: ENHANCED INTERACTIVE FUNCTIONS
# ================================

def demo_document_isolation():
    """Enhanced demonstration of document isolation"""
    print("\n🧪 ENHANCED DOCUMENT ISOLATION DEMONSTRATION")
    print("="*60)

    stats = processor.vector_store.get_collection_stats()
    filenames = stats['filenames']

    if len(filenames) >= 2:
        filename_1, filename_2 = filenames[0], filenames[1]

        test_query = "total amount"
        print(f"🔍 Testing query '{test_query}' across isolated documents:")

        # Query document 1
        print(f"\n📄 RESULTS FOR: {filename_1}")
        result_1 = processor.vector_store.query_document_specific(test_query, filename_1)

        # Query document 2
        print(f"\n📄 RESULTS FOR: {filename_2}")
        result_2 = processor.vector_store.query_document_specific(test_query, filename_2)

        print("\n✅ ISOLATION DEMO COMPLETE: each query was restricted to its own document")

    else:
        print("⚠️ Need at least 2 documents to demonstrate isolation")

def query_specific_document(filename: str, task: str):
    """Query a specific document with enhanced tracking"""

    print(f"\n❓ QUERYING DOCUMENT: {filename}")
    print(f"🎯 TASK: {task}")
    print("─" * 50)

    result = processor.process_expense_task(filename, task)

    print(f"\n📝 RESPONSE: {result.get('response', 'No response')}")
    # print(f"🎯 CONFIDENCE: {result.get('confidence', 0):.3f}")
    print(f"📚 CHUNKS USED: {result.get('context_chunks_used', 0)}")

    token_usage = result.get('token_usage', {})
    if token_usage:
        print(f"🔢 TOKEN USAGE:")
        print(f"   📥 Input: {token_usage.get('input_tokens', 0)}")
        print(f"   📤 Output: {token_usage.get('output_tokens', 0)}")
        print(f"   🎯 Total: {token_usage.get('total_tokens', 0)}")

print("\n🎯 ENHANCED RAG SYSTEM READY!")
print("="*70)

print("""
✅ ENHANCED SYSTEM COMPONENTS:
- 📄 Filename-based Document Management
- 🔍 Enhanced OCR Processor (detailed progress tracking)
- 🗄️ Enhanced Vector Store (ChromaDB + isolation)
- 🤖 Token Usage Tracker (comprehensive monitoring)
- 🎯 Enhanced RAG Processor (detailed inference tracking)
- 📊 Enhanced LangGraph Workflow (step-by-step progress)

📋 DEMONSTRATION COMPLETED:
- ✅ Document ingestion with filename-based IDs
- ✅ Task-based information extraction with progress tracking
- ✅ Context optimization with detailed metrics
- ✅ Cross-document contamination prevention verified
- ✅ Token usage monitoring for all LLM calls

🎯 CONTEXT OPTIMIZATION FEATURES:
- ✅ Chunk size optimization (500 chars max per chunk)
- ✅ Relevant chunk filtering (top 3 per query with relevance scores)
- ✅ Context length limits (1500 chars max to prevent overload)
- ✅ Task-specific query optimization (tailored search terms)
- ✅ Token usage tracking (input/output/total for every LLM call)
""")

# Show available documents for interaction
stats = processor.vector_store.get_collection_stats()
if stats['filenames']:
    print(f"\n📄 AVAILABLE DOCUMENTS FOR QUERYING:")
    for i, filename in enumerate(stats['filenames'], 1):
        print(f"   {i}. {filename}")

    # Demo enhanced isolation
    demo_document_isolation()

    # Show available tasks
    tasks = processor.task_manager.list_available_tasks()
    print(f"\n🎯 AVAILABLE TASKS:")
    for i, task in enumerate(tasks, 1):
        print(f"   {i}. {task}")

else:
    print("\n⚠️ No documents available. Process some documents first.")

print(f"\n🔢 FINAL TOKEN USAGE SUMMARY:")
token_tracker.print_summary()

In [None]:
!pip install llama-index

# Task
Explain the provided Python code and add a front end using a web framework (Flask, Streamlit, or Gradio) to accept expense claim documents as file uploads, process them using the existing RAG workflow, and display the results.

## Choose a web framework

### Subtask:
Choose a suitable Python web framework (e.g., Flask, Streamlit, or Gradio) to build the front end.


## Design the user interface

### Subtask:
Create a simple web page with a file upload form for users to submit their expense claim documents.


**Reasoning**:
Create a Gradio interface with file upload and output components as requested by the instructions.



In [None]:
import gradio as gr

def process_document_ui(file):
    """Placeholder handler: report the uploaded file without processing it yet."""
    if file is None:
        # Return None (not an empty string) for the Dataframe output
        return "Please upload a file.", "", None

    # Simulate processing
    filename = file.name
    status = f"Processing file: {filename}"
    extracted_details = f"Simulated details for {filename}"

    return status, extracted_details, None

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document to extract details.")

    file_input = gr.File(label="Upload Expense Document")
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status")
    details_output = gr.Textbox(label="Extracted Details")
    # Add a placeholder for potential future output like a dataframe or summary
    summary_output = gr.Dataframe(label="Summary", visible=False)

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

demo.launch()

## Implement file upload handling

### Subtask:
Write code to receive the uploaded files on the server side and save them temporarily.


**Reasoning**:
Modify the `process_document_ui` function to access the temporary file path provided by Gradio when a file is uploaded and return a status confirming receipt of the file path.
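
Because the pipeline keys documents by filename, it may help to copy the upload out of the framework's temporary location into a working directory while keeping the original file name. The sketch below is a minimal, standard-library illustration of that idea; the helper name `save_upload_to_workdir` and the `uploads` directory are assumptions, not part of the notebook.

```python
import shutil
import tempfile
from pathlib import Path


def save_upload_to_workdir(upload_path: str, workdir: str) -> str:
    """Hypothetical helper: copy an uploaded temp file into a working directory.

    Returns the path of the copy, which keeps the original file name.
    """
    source = Path(upload_path)
    if not source.is_file():
        raise FileNotFoundError(f"Uploaded file not found: {source}")
    target_dir = Path(workdir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / source.name
    shutil.copy2(source, target)
    return str(target)


# Illustration with a throwaway file standing in for a real upload
with tempfile.TemporaryDirectory() as tmp:
    fake_upload = Path(tmp) / "receipt.pdf"
    fake_upload.write_bytes(b"%PDF-1.4 demo")
    saved_path = save_upload_to_workdir(str(fake_upload), str(Path(tmp) / "uploads"))
    saved_name = Path(saved_path).name
    saved_exists = Path(saved_path).exists()
```

In the Gradio handler, the value passed as `upload_path` would be the temporary path obtained from the uploaded file object.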



In [None]:
import gradio as gr

def process_document_ui(file):
    """Function to receive the uploaded file and get its temporary path."""
    if file is None:
        # Return None (not an empty string) for the Dataframe output
        return "Please upload a file.", "", None

    # Access the temporary file path provided by Gradio
    file_path = file.name

    # Confirm receipt of the file path
    status = f"Received file: {file_path}. Ready for processing."
    extracted_details = ""  # Keep as placeholder for now
    summary = None  # Placeholder for the Dataframe output

    return status, extracted_details, summary

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document to extract details.")

    file_input = gr.File(label="Upload Expense Document")
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status")
    details_output = gr.Textbox(label="Extracted Details")
    # Add a placeholder for potential future output like a dataframe or summary
    summary_output = gr.Dataframe(label="Summary", visible=False)

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

demo.launch()

## Integrate with the RAG workflow

### Subtask:
Modify the existing RAG workflow to accept file paths from the front end as input.


**Reasoning**:
I need to modify the existing RAG workflow to accept file paths from the front end. This involves updating `create_enhanced_rag_workflow` to take the processor as an argument, and updating `enhanced_ingestion_node` to read the list of file paths from the workflow state (failing early when the list is empty) rather than relying on the hardcoded sample directory used in the demonstration cell. I will also update the `process_document_ui` function in the Gradio app to pass the temporary file path to the workflow.
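
As a rough sketch of the hand-off between the UI and the workflow: the handler can build the initial state from the uploaded paths, invoke the compiled workflow, and turn the final state into text for the output boxes. The helper names `build_initial_state` and `summarize_final_state` are hypothetical; the state keys mirror `EnhancedRAGWorkflowState`, and the workflow call is simulated here so the example runs on its own.

```python
from typing import Any, Dict, List, Tuple


def build_initial_state(file_paths: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: initial workflow state for a list of uploaded paths."""
    return {
        "file_paths": list(file_paths),
        "current_file_index": 0,
        "processed_filenames": [],
        "current_filename": "",
        "task_results": {},
        "workflow_status": "initialized",
        "error": None,
    }


def summarize_final_state(final_state: Dict[str, Any]) -> Tuple[str, str]:
    """Hypothetical helper: (status line, details text) for two text outputs."""
    status = f"Workflow status: {final_state.get('workflow_status', 'unknown')}"
    if final_state.get("error"):
        status += f" (last error: {final_state['error']})"

    lines: List[str] = []
    for filename, results in final_state.get("task_results", {}).items():
        lines.append(f"{filename}:")
        if "error" in results:
            lines.append(f"  error: {results['error']}")
            continue
        for task_name, task_result in results.items():
            if isinstance(task_result, dict):
                lines.append(f"  {task_name}: {task_result.get('response', '')}")
    return status, "\n".join(lines)


state = build_initial_state(["/tmp/receipt.pdf"])
# In the notebook this step would be roughly: final_state = workflow.invoke(state)
# Here the result is simulated with made-up values.
final_state = dict(
    state,
    workflow_status="complete",
    task_results={"receipt.pdf": {"extract_amount": {"response": "120.50"}}},
)
status_text, details_text = summarize_final_state(final_state)
```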



In [None]:
from langgraph.graph import StateGraph
from typing import TypedDict, List, Dict, Any, Optional, Tuple
from pathlib import Path
from datetime import datetime
import json
import pandas as pd
import gradio as gr

# Assume EnhancedRAGExpenseProcessor, TokenUsageTracker, EnhancedExpenseTaskManager,
# FilenameBasedDocumentManager, EnhancedOCRProcessor, EnhancedIsolatedVectorStore
# are defined in previous cells and available in the environment.
# Also assume ollama is imported and configured.

class EnhancedRAGWorkflowState(TypedDict):
    """Enhanced state for RAG workflow"""
    file_paths: List[str] # Modified to accept a list of file paths
    current_file_index: int
    processed_filenames: List[str]
    current_filename: str
    task_results: Dict[str, Dict[str, Any]]
    workflow_status: str
    error: Optional[str]

def create_enhanced_rag_workflow(processor: EnhancedRAGExpenseProcessor) -> Any:
    """Create and compile the enhanced LangGraph workflow (returns the compiled graph, not a StateGraph)"""

    def enhanced_ingestion_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced ingestion with detailed tracking"""

        print("\n" + "🔄 WORKFLOW: ENHANCED INGESTION PHASE STARTING")
        print("="*70)

        # Get file paths from the state
        file_paths = state.get("file_paths", [])
        processed_filenames = []

        if not file_paths:
            state["workflow_status"] = "ingestion_failed"
            state["error"] = "No file paths provided for ingestion."
            print("❌ INGESTION FAILED: No file paths provided.")
            return state


        for i, file_path in enumerate(file_paths, 1):
            print(f"\n[{i}/{len(file_paths)}] Processing file: {Path(file_path).name}")

            try:
                # Use the processor to ingest the document
                filename = processor.ingest_document(file_path)
                if filename:
                    processed_filenames.append(filename)
                    print(f"✅ Successfully ingested: {filename}")
                else:
                    print(f"❌ Failed to ingest: {Path(file_path).name}")

            except Exception as e:
                print(f"❌ Error ingesting {Path(file_path).name}: {e}")
                state["error"] = str(e) # Store the error in state

        state["processed_filenames"] = processed_filenames
        state["workflow_status"] = "ingestion_complete" if processed_filenames else "ingestion_failed"

        print(f"\n📊 INGESTION PHASE COMPLETED")
        print(f"   ✅ Successfully processed: {len(processed_filenames)} files")
        print(f"   ❌ Failed: {len(file_paths) - len(processed_filenames)} files")

        return state

    def enhanced_task_processing_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced task processing with detailed tracking"""

        print("\n🎯 WORKFLOW: ENHANCED TASK PROCESSING PHASE STARTING")
        print("="*70)

        processed_filenames = state.get("processed_filenames", [])
        task_results = state.get("task_results", {}) # Initialize or get existing results

        if not processed_filenames:
            state["workflow_status"] = "processing_skipped"
            print("⚠️ TASK PROCESSING SKIPPED: No documents successfully ingested.")
            return state


        for i, filename in enumerate(processed_filenames, 1):
            print(f"\n[{i}/{len(processed_filenames)}] Processing tasks for: {filename}")

            try:
                results = processor.process_all_tasks_for_document(filename)
                task_results[filename] = results
                print(f"✅ Completed all tasks for: {filename}")

            except Exception as e:
                print(f"❌ Error processing tasks for {filename}: {e}")
                task_results[filename] = {"error": str(e)}
                state["error"] = str(e) # Store the error in state


        state["task_results"] = task_results
        state["workflow_status"] = "processing_complete"

        print(f"\n📊 TASK PROCESSING PHASE COMPLETED")
        print(f"   📄 Documents processed: {len(task_results)}")

        return state

    def enhanced_results_compilation_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced results compilation with detailed stats and CSV export"""

        print("\n📊 WORKFLOW: ENHANCED RESULTS COMPILATION STARTING")
        print("="*70)

        task_results = state.get("task_results", {})
        # token_tracker is a global instance assumed to be available
        # processor instance (and its vector_store) is also assumed to be available

        if not task_results:
            state["workflow_status"] = "compilation_skipped"
            print("⚠️ RESULTS COMPILATION SKIPPED: No task results to compile.")
            token_tracker.print_summary()  # Print summary even if compilation skipped
            return state

        # Compile detailed statistics
        total_documents = len(task_results)
        successful_documents = sum(1 for r in task_results.values() if "error" not in r)

        # Save results with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"enhanced_rag_expense_results_{timestamp}.json"
        csv_file = f"enhanced_rag_expense_results_{timestamp}.csv"

        # Create comprehensive results package
        comprehensive_results = {
            "timestamp": timestamp,
            "summary": {
                "total_documents": total_documents,
                "successful_documents": successful_documents,
                "failed_documents": total_documents - successful_documents
            },
            "token_usage_summary": {
                "total_calls": token_tracker.call_count,
                "total_input_tokens": token_tracker.total_input_tokens,
                "total_output_tokens": token_tracker.total_output_tokens,
                "total_tokens": token_tracker.total_tokens
            },
            "document_results": task_results,
            "token_call_history": token_tracker.call_history
        }

        # Save JSON results
        try:
            with open(results_file, 'w') as f:
                json.dump(comprehensive_results, f, indent=2, default=str)
            print(f"💾 JSON RESULTS SAVED TO: {results_file}")
        except Exception as e:
            print(f"❌ Error saving JSON results: {e}")
            state["error"] = f"Error saving JSON results: {e}"


        # Create CSV from results
        csv_rows = []

        for filename, doc_results in task_results.items():
            if "error" in doc_results:
                # Add error row
                csv_rows.append({
                    "filename": filename,
                    "task": "workflow_error", # Indicate workflow level error for this document
                    "response": doc_results["error"],
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
            else:
                # Process each task for this document
                for task_name, task_result in doc_results.items():
                    if isinstance(task_result, dict):
                        token_usage = task_result.get("token_usage", {})
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": task_result.get("response", ""),
                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                            "input_tokens": token_usage.get("input_tokens", 0),
                            "output_tokens": token_usage.get("output_tokens", 0),
                            "total_tokens": token_usage.get("total_tokens", 0)
                        })
                    else:
                        # Handle non-dict task results (treated as task-level errors)
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": f"Task error: {task_result}",
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })


        # Save CSV
        if csv_rows:
            try:
                df = pd.DataFrame(csv_rows)

                # Reorder columns for better readability - handle missing columns gracefully
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in df.columns]
                df = df[existing_columns]


                # Save to CSV
                df.to_csv(csv_file, index=False, encoding='utf-8')
                print(f"💾 CSV RESULTS SAVED TO: {csv_file}")

                # Display summary statistics from CSV
                print(f"\n📊 CSV Summary:")
                print(f"   📄 Total rows: {len(df)}")
                # Ensure 'filename' column exists before calling nunique
                if 'filename' in df.columns:
                    print(f"   📁 Documents: {df['filename'].nunique()}")
                    # Exclude workflow-level error rows when averaging tasks per document
                    task_rows = df[df['task'] != 'workflow_error']
                    if not task_rows.empty:
                        print(f"   🎯 Tasks per document: {task_rows.groupby('filename').size().mean():.1f}")
                    else:
                        print("   🎯 Tasks per document: N/A (No successful tasks)")

                # Ensure 'total_tokens' column exists and is numeric before summing
                if 'total_tokens' in df.columns:
                    try:
                        df['total_tokens'] = pd.to_numeric(df['total_tokens'], errors='coerce').fillna(0)
                        print(f"   🔢 Total tokens used: {df['total_tokens'].sum():,}")
                    except Exception as e:
                        print(f"⚠️ Could not calculate total tokens from CSV: {e}")
                else:
                    print("⚠️ 'total_tokens' column not found in CSV.")

            except Exception as e:
                print(f"❌ Error processing or saving CSV results: {e}")
                state["error"] = f"Error processing or saving CSV results: {e}"


        # Also save a summary CSV with aggregated data per document
        summary_csv_file = f"enhanced_rag_expense_summary_{timestamp}.csv"
        summary_rows = []

        for filename, doc_results in task_results.items():
            if "error" not in doc_results:
                row = {"filename": filename}

                # Extract key information from each task
                for task_name in ["extract_amount", "extract_date", "extract_vendor",
                                "extract_category", "extract_tax"]:
                    if task_name in doc_results and isinstance(doc_results[task_name], dict):
                        response = doc_results[task_name].get("response", "")
                        # Clean the response (take first line or first 100 chars)
                        cleaned = response.split('\n')[0][:100] if response else ""
                        row[task_name] = cleaned
                    else:
                        row[task_name] = f"Task {task_name} failed or not found"


                # Add token totals
                total_tokens = sum(
                    doc_results.get(task, {}).get("token_usage", {}).get("total_tokens", 0)
                    for task in doc_results if isinstance(doc_results.get(task), dict)
                )
                row["total_tokens_used"] = total_tokens

                summary_rows.append(row)
            else:
                summary_rows.append({"filename": filename, "workflow_error": doc_results["error"]})  # Add error to summary


        if summary_rows:
            try:
                summary_df = pd.DataFrame(summary_rows)
                summary_df.to_csv(summary_csv_file, index=False, encoding='utf-8')
                print(f"💾 SUMMARY CSV SAVED TO: {summary_csv_file}")
            except Exception as e:
                print(f"❌ Error saving summary CSV: {e}")
                state["error"] = f"Error saving summary CSV: {e}"


        print(f"\n📊 FILES SAVED:")
        print(f"   📄 Detailed JSON: {results_file}")
        print(f"   📄 Detailed CSV: {csv_file}")
        print(f"   📄 Summary CSV: {summary_csv_file}")
        print(f"   📄 Documents processed: {successful_documents}/{total_documents}")

        # Print token usage summary
        token_tracker.print_summary()

        state["workflow_status"] = "complete"
        return state


    # Build enhanced workflow
    workflow = StateGraph(EnhancedRAGWorkflowState)

    workflow.add_node("enhanced_ingestion", enhanced_ingestion_node)
    workflow.add_node("enhanced_task_processing", enhanced_task_processing_node)
    workflow.add_node("enhanced_results_compilation", enhanced_results_compilation_node)

    workflow.set_entry_point("enhanced_ingestion")
    workflow.add_edge("enhanced_ingestion", "enhanced_task_processing")
    workflow.add_edge("enhanced_task_processing", "enhanced_results_compilation")
    workflow.set_finish_point("enhanced_results_compilation")

    return workflow.compile()

# Update the Gradio UI function to call the workflow
def process_document_ui(file):
    """Function to receive the uploaded file, run the workflow, and display results."""
    if file is None:
        return "Please upload a file.", "No file uploaded.", pd.DataFrame() # Return empty DataFrame

    file_path = file.name # Gradio provides the temporary path here

    status = f"Received file: {file_path}. Starting RAG workflow..."
    print(status) # Print status to console

    # Initialize the RAG processor and workflow
    # Assuming processor and token_tracker are accessible globally or initialized here
    try:
        processor = EnhancedRAGExpenseProcessor()
        workflow = create_enhanced_rag_workflow(processor)
    except Exception as e:
        error_msg = f"Error initializing RAG components: {e}"
        print(f"❌ {error_msg}")
        return error_msg, "Initialization failed.", pd.DataFrame()


    # Execute enhanced workflow with the uploaded file path
    initial_state = {
        "file_paths": [file_path], # Pass the uploaded file path as a list
        "current_file_index": 0,
        "processed_filenames": [],
        "current_filename": "",
        "task_results": {},
        "workflow_status": "initialized",
        "error": None
    }

    try:
        final_state = workflow.invoke(initial_state)
        status = f"Workflow completed with status: {final_state.get('workflow_status')}"
        print(status) # Print final status

        # Process final state to display results
        task_results = final_state.get("task_results", {})
        if final_state.get("error"):
            details = f"Workflow Error: {final_state['error']}"
            # Attempt to display any partial results if available
            if task_results:
                details += "\n\nPartial Results:"
                # Create a simple string representation of partial results
                for filename, doc_results in task_results.items():
                    details += f"\nDocument: {filename}"
                    if "error" in doc_results:
                        details += f"\n  Error: {doc_results['error']}"
                    else:
                        for task_name, task_result in doc_results.items():
                            if isinstance(task_result, dict):
                                details += f"\n  {task_name}: {task_result.get('response', 'No response')[:100]}..."
                            else:
                                details += f"\n  {task_name}: Error - {task_result}"

            return status, details, pd.DataFrame()  # Return empty DataFrame on error

        # Compile results for display
        if task_results:
            compiled_details = ""
            csv_rows = [] # Prepare data for DataFrame display

            for filename, results in task_results.items():
                compiled_details += f"\n📄 DOCUMENT: {filename}\n" + "─" * 50 + "\n"
                if "error" in results:
                    compiled_details += f"❌ Error: {results['error']}\n"
                    # Add error row to CSV data
                    csv_rows.append({
                        "filename": filename,
                        "task": "workflow_error",
                        "response": results["error"],
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })
                else:
                    for task_name, task_result in results.items():
                        if isinstance(task_result, dict):
                            response = task_result.get("response", "No response")
                            chunks_used = task_result.get("context_chunks_used", 0)
                            token_usage = task_result.get("token_usage", {})

                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   📝 Response: {response[:200]}...\n"  # Limit display length
                            compiled_details += f"   📚 Chunks used: {chunks_used}\n"
                            if token_usage:
                                compiled_details += f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}\n"

                            # Add task result to CSV data
                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": response,
                                "context_chunks_used": chunks_used,
                                "input_tokens": token_usage.get("input_tokens", 0),
                                "output_tokens": token_usage.get("output_tokens", 0),
                                "total_tokens": token_usage.get("total_tokens", 0)
                            })
                        else:
                            # Handle task-specific errors in display and CSV
                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   ❌ Task Error: {task_result}\n"
                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": f"Task error: {task_result}",
                                "context_chunks_used": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "total_tokens": 0
                            })


            # Create DataFrame for the summary output
            if csv_rows:
                results_df = pd.DataFrame(csv_rows)
                # Ensure column order
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in results_df.columns]
                results_df = results_df[existing_columns]

                # Convert token columns to numeric, handling errors
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                    if token_col in results_df.columns:
                        results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

            else:
                results_df = pd.DataFrame({"Status": ["No task results generated."]})  # Empty DataFrame or error indicator


            return status, compiled_details, results_df # Return DataFrame for display

        else:
            return status, "No task results generated.", pd.DataFrame() # Return empty DataFrame if no task results

    except Exception as e:
        error_msg = f"Workflow execution failed: {e}"
        print(f"❌ {error_msg}")
        return status, error_msg, pd.DataFrame() # Return empty DataFrame on workflow execution error

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document (PDF, JPG, PNG, etc.) to extract details.")

    file_input = gr.File(label="Upload Expense Document", file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff"])
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status", max_lines=5)
    details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
    summary_output = gr.Dataframe(label="Extracted Details (Structured)") # Use Dataframe for structured output

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

# Note: demo.launch() will be called separately when running the Gradio app
# For this subtask, we define the function but don't call launch() yet.

**Reasoning**:
The previous code block failed because the `langgraph` library was not found. I need to install this library before I can use `StateGraph`.
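
Before installing, it can help to confirm which packages are actually missing. The following is a minimal sketch using only the standard library; `missing_packages` is a hypothetical helper name, and it assumes top-level module names only (dotted names behave differently with `find_spec`).

```python
import importlib.util

def missing_packages(names):
    """Return the top-level module names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# "json" is a stdlib module used here only as a name that is always present;
# "langgraph" is the package the failing cell needs.
print(missing_packages(["json", "langgraph"]))
```

If `langgraph` appears in the printed list, the install cell needs to run before `StateGraph` can be imported.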



In [None]:
!pip install langgraph

**Reasoning**:
I have successfully installed `langgraph`. Now I can re-run the code to modify the RAG workflow and the Gradio UI function to handle the uploaded file path.
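
The change amounts to carrying the uploaded path in the state and passing that state through each stage in order. Below is a plain-Python stand-in for that linear flow, not LangGraph itself: `run_linear`, `ingest`, and `process` are hypothetical names, and the stage bodies are placeholders rather than the real processor calls.

```python
from pathlib import Path
from typing import Any, Callable, Dict, List

State = Dict[str, Any]

def run_linear(nodes: List[Callable[[State], State]], state: State) -> State:
    """Call each node in order, feeding it the state returned by the previous one."""
    for node in nodes:
        state = node(state)
    return state

def ingest(state: State) -> State:
    # Placeholder for ingestion: record the base name of every supplied path.
    state["processed_filenames"] = [Path(p).name for p in state["file_paths"]]
    state["workflow_status"] = (
        "ingestion_complete" if state["processed_filenames"] else "ingestion_failed"
    )
    return state

def process(state: State) -> State:
    # Placeholder for task processing: store an empty result per ingested file.
    state["task_results"] = {name: {} for name in state["processed_filenames"]}
    state["workflow_status"] = "processing_complete"
    return state

final = run_linear([ingest, process], {"file_paths": ["/tmp/receipt.pdf"]})
print(final["workflow_status"], list(final["task_results"]))
```

The real graph is built with `StateGraph`, `add_node`, and `add_edge`; this sketch only shows how the state dictionary moves from one stage to the next.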



In [None]:
from langgraph.graph import StateGraph
from typing import TypedDict, List, Dict, Any, Optional, Tuple
from pathlib import Path
from datetime import datetime
import json
import pandas as pd
import gradio as gr

# Assume EnhancedRAGExpenseProcessor, TokenUsageTracker, EnhancedExpenseTaskManager,
# FilenameBasedDocumentManager, EnhancedOCRProcessor, EnhancedIsolatedVectorStore
# are defined in previous cells and available in the environment.
# Also assume ollama is imported and configured.

class EnhancedRAGWorkflowState(TypedDict):
    """Enhanced state for RAG workflow"""
    file_paths: List[str] # Modified to accept a list of file paths
    current_file_index: int
    processed_filenames: List[str]
    current_filename: str
    task_results: Dict[str, Dict[str, Any]]
    workflow_status: str
    error: Optional[str]

def create_enhanced_rag_workflow(processor: EnhancedRAGExpenseProcessor) -> Any:
    """Create the enhanced LangGraph workflow for RAG processing and return it compiled (the result of `workflow.compile()`)."""

    def enhanced_ingestion_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced ingestion with detailed tracking"""

        print("\n" + "🔄 WORKFLOW: ENHANCED INGESTION PHASE STARTING")
        print("="*70)

        # Get file paths from the state
        file_paths = state.get("file_paths", [])
        processed_filenames = []

        if not file_paths:
            state["workflow_status"] = "ingestion_failed"
            state["error"] = "No file paths provided for ingestion."
            print("❌ INGESTION FAILED: No file paths provided.")
            return state


        for i, file_path in enumerate(file_paths, 1):
            print(f"\n[{i}/{len(file_paths)}] Processing file: {Path(file_path).name}")

            try:
                # Use the processor to ingest the document
                filename = processor.ingest_document(file_path)
                if filename:
                    processed_filenames.append(filename)
                    print(f"✅ Successfully ingested: {filename}")
                else:
                    print(f"❌ Failed to ingest: {Path(file_path).name}")

            except Exception as e:
                print(f"❌ Error ingesting {Path(file_path).name}: {e}")
                state["error"] = str(e) # Store the error in state

        state["processed_filenames"] = processed_filenames
        state["workflow_status"] = "ingestion_complete" if processed_filenames else "ingestion_failed"

        print(f"\n📊 INGESTION PHASE COMPLETED")
        print(f"   ✅ Successfully processed: {len(processed_filenames)} files")
        print(f"   ❌ Failed: {len(file_paths) - len(processed_filenames)} files")

        return state

    def enhanced_task_processing_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced task processing with detailed tracking"""

        print("\n🎯 WORKFLOW: ENHANCED TASK PROCESSING PHASE STARTING")
        print("="*70)

        processed_filenames = state.get("processed_filenames", [])
        task_results = state.get("task_results", {}) # Initialize or get existing results

        if not processed_filenames:
            state["workflow_status"] = "processing_skipped"
            print("⚠️ TASK PROCESSING SKIPPED: No documents successfully ingested.")
            return state


        for i, filename in enumerate(processed_filenames, 1):
            print(f"\n[{i}/{len(processed_filenames)}] Processing tasks for: {filename}")

            try:
                results = processor.process_all_tasks_for_document(filename)
                task_results[filename] = results
                print(f"✅ Completed all tasks for: {filename}")

            except Exception as e:
                print(f"❌ Error processing tasks for {filename}: {e}")
                task_results[filename] = {"error": str(e)}
                state["error"] = str(e) # Store the error in state


        state["task_results"] = task_results
        state["workflow_status"] = "processing_complete"

        print(f"\n📊 TASK PROCESSING PHASE COMPLETED")
        print(f"   📄 Documents processed: {len(task_results)}")

        return state

    def enhanced_results_compilation_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced results compilation with detailed stats and CSV export"""

        print("\n📊 WORKFLOW: ENHANCED RESULTS COMPILATION STARTING")
        print("="*70)

        task_results = state.get("task_results", {})
        # token_tracker is a global instance assumed to be available
        # processor instance (and its vector_store) is also assumed to be available

        if not task_results:
            state["workflow_status"] = "compilation_skipped"
            print("⚠️ RESULTS COMPILATION SKIPPED: No task results to compile.")
            token_tracker.print_summary()  # Print summary even if compilation skipped
            return state


        # Compile detailed statistics
        total_documents = len(task_results)
        successful_documents = sum(1 for r in task_results.values() if "error" not in r)

        # Save results with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"enhanced_rag_expense_results_{timestamp}.json"
        csv_file = f"enhanced_rag_expense_results_{timestamp}.csv"

        # Create comprehensive results package
        comprehensive_results = {
            "timestamp": timestamp,
            "summary": {
                "total_documents": total_documents,
                "successful_documents": successful_documents,
                "failed_documents": total_documents - successful_documents
            },
            "token_usage_summary": {
                "total_calls": token_tracker.call_count,
                "total_input_tokens": token_tracker.total_input_tokens,
                "total_output_tokens": token_tracker.total_output_tokens,
                "total_tokens": token_tracker.total_tokens
            },
            "document_results": task_results,
            "token_call_history": token_tracker.call_history
        }

        # Save JSON results
        try:
            with open(results_file, 'w') as f:
                json.dump(comprehensive_results, f, indent=2, default=str)
            print(f"💾 JSON RESULTS SAVED TO: {results_file}")
        except Exception as e:
            print(f"❌ Error saving JSON results: {e}")
            state["error"] = f"Error saving JSON results: {e}"


        # Create CSV from results
        csv_rows = []

        for filename, doc_results in task_results.items():
            if "error" in doc_results:
                # Add error row
                csv_rows.append({
                    "filename": filename,
                    "task": "workflow_error", # Indicate workflow level error for this document
                    "response": doc_results["error"],
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
            else:
                # Process each task for this document
                for task_name, task_result in doc_results.items():
                    if isinstance(task_result, dict):
                        token_usage = task_result.get("token_usage", {})
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": task_result.get("response", ""),
                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                            "input_tokens": token_usage.get("input_tokens", 0),
                            "output_tokens": token_usage.get("output_tokens", 0),
                            "total_tokens": token_usage.get("total_tokens", 0)
                        })
                    else:
                         # Handle task-specific errors
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": f"Task error: {task_result}",
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })


        # Save CSV
        if csv_rows:
            try:
                df = pd.DataFrame(csv_rows)

                # Reorder columns for better readability - handle missing columns gracefully
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in df.columns]
                df = df[existing_columns]


                # Save to CSV
                df.to_csv(csv_file, index=False, encoding='utf-8')
                print(f"💾 CSV RESULTS SAVED TO: {csv_file}")

                # Display summary statistics from CSV
                print(f"\n📊 CSV Summary:")
                print(f"   📄 Total rows: {len(df)}")
                # Ensure 'filename' column exists before calling nunique
                if 'filename' in df.columns:
                    print(f"   📁 Documents: {df['filename'].nunique()}")
                    # Handle the case where no tasks were processed successfully
                    task_rows = df[df['task'] != 'workflow_error']
                    if not task_rows.empty:
                        print(f"   🎯 Tasks per document: {task_rows.groupby('filename').size().mean():.1f}")
                    else:
                        print("   🎯 Tasks per document: N/A (No successful tasks)")

                # Ensure 'total_tokens' column exists and is numeric before summing
                if 'total_tokens' in df.columns:
                    try:
                        df['total_tokens'] = pd.to_numeric(df['total_tokens'], errors='coerce').fillna(0).astype(int)
                        print(f"   🔢 Total tokens used: {df['total_tokens'].sum():,}")
                    except Exception as e:
                        print(f"⚠️ Could not calculate total tokens from CSV: {e}")
                else:
                    print("⚠️ 'total_tokens' column not found in CSV.")

            except Exception as e:
                print(f"❌ Error processing or saving CSV results: {e}")
                state["error"] = f"Error processing or saving CSV results: {e}"


        # Also save a summary CSV with aggregated data per document
        summary_csv_file = f"enhanced_rag_expense_summary_{timestamp}.csv"
        summary_rows = []

        for filename, doc_results in task_results.items():
            if "error" not in doc_results:
                row = {"filename": filename}

                # Extract key information from each task
                for task_name in ["extract_amount", "extract_date", "extract_vendor",
                                "extract_category", "extract_tax"]:
                    if task_name in doc_results and isinstance(doc_results[task_name], dict):
                        response = doc_results[task_name].get("response", "")
                        # Clean the response (take first line or first 100 chars)
                        cleaned = response.split('\n')[0][:100] if response else ""
                        row[task_name] = cleaned
                    else:
                        row[task_name] = f"Task {task_name} failed or not found"


                # Add token totals
                total_tokens = sum(
                    doc_results.get(task, {}).get("token_usage", {}).get("total_tokens", 0)
                    for task in doc_results if isinstance(doc_results.get(task), dict)
                )
                row["total_tokens_used"] = total_tokens

                summary_rows.append(row)
            else:
                summary_rows.append({"filename": filename, "workflow_error": doc_results["error"]})  # Add error to summary


        if summary_rows:
            try:
                summary_df = pd.DataFrame(summary_rows)
                summary_df.to_csv(summary_csv_file, index=False, encoding='utf-8')
                print(f"💾 SUMMARY CSV SAVED TO: {summary_csv_file}")
            except Exception as e:
                print(f"❌ Error saving summary CSV: {e}")
                state["error"] = f"Error saving summary CSV: {e}"


        print(f"\n📊 FILES SAVED:")
        print(f"   📄 Detailed JSON: {results_file}")
        print(f"   📄 Detailed CSV: {csv_file}")
        print(f"   📄 Summary CSV: {summary_csv_file}")
        print(f"   📄 Documents processed: {successful_documents}/{total_documents}")

        # Print token usage summary
        token_tracker.print_summary()

        state["workflow_status"] = "complete"
        return state


    # Build enhanced workflow
    workflow = StateGraph(EnhancedRAGWorkflowState)

    workflow.add_node("enhanced_ingestion", enhanced_ingestion_node)
    workflow.add_node("enhanced_task_processing", enhanced_task_processing_node)
    workflow.add_node("enhanced_results_compilation", enhanced_results_compilation_node)

    workflow.set_entry_point("enhanced_ingestion")
    workflow.add_edge("enhanced_ingestion", "enhanced_task_processing")
    workflow.add_edge("enhanced_task_processing", "enhanced_results_compilation")
    workflow.set_finish_point("enhanced_results_compilation")

    return workflow.compile()

# Update the Gradio UI function to call the workflow
def process_document_ui(file):
    """Function to receive the uploaded file, run the workflow, and display results."""
    if file is None:
        return "Please upload a file.", "No file uploaded.", pd.DataFrame() # Return empty DataFrame

    file_path = getattr(file, "name", file)  # Works for a temp-file object (.name) or a plain path string

    status = f"Received file: {file_path}. Starting RAG workflow..."
    print(status) # Print status to console

    # Initialize the RAG processor and workflow
    # Assuming processor and token_tracker are accessible globally or initialized here
    try:
        processor = EnhancedRAGExpenseProcessor()
        workflow = create_enhanced_rag_workflow(processor)
    except Exception as e:
        error_msg = f"Error initializing RAG components: {e}"
        print(f"❌ {error_msg}")
        return error_msg, "Initialization failed.", pd.DataFrame()


    # Execute enhanced workflow with the uploaded file path
    initial_state = {
        "file_paths": [file_path], # Pass the uploaded file path as a list
        "current_file_index": 0,
        "processed_filenames": [],
        "current_filename": "",
        "task_results": {},
        "workflow_status": "initialized",
        "error": None
    }

    try:
        final_state = workflow.invoke(initial_state)
        status = f"Workflow completed with status: {final_state.get('workflow_status')}"
        print(status) # Print final status

        # Process final state to display results
        task_results = final_state.get("task_results", {})
        if final_state.get("error"):
            details = f"Workflow Error: {final_state['error']}"
            # Attempt to display any partial results if available
            if task_results:
                details += "\n\nPartial Results:"
                # Create a simple string representation of partial results
                for filename, doc_results in task_results.items():
                    details += f"\nDocument: {filename}"
                    if "error" in doc_results:
                        details += f"\n  Error: {doc_results['error']}"
                    else:
                        for task_name, task_result in doc_results.items():
                            if isinstance(task_result, dict):
                                details += f"\n  {task_name}: {task_result.get('response', 'No response')[:100]}..."
                            else:
                                details += f"\n  {task_name}: Error - {task_result}"

            return status, details, pd.DataFrame()  # Return empty DataFrame on error

        # Compile results for display
        if task_results:
            compiled_details = ""
            csv_rows = [] # Prepare data for DataFrame display

            for filename, results in task_results.items():
                compiled_details += f"\n📄 DOCUMENT: {filename}\n" + "─" * 50 + "\n"
                if "error" in results:
                    compiled_details += f"❌ Error: {results['error']}\n"
                    # Add error row to CSV data
                    csv_rows.append({
                        "filename": filename,
                        "task": "workflow_error",
                        "response": results["error"],
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })
                else:
                    for task_name, task_result in results.items():
                        if isinstance(task_result, dict):
                            response = task_result.get("response", "No response")
                            chunks_used = task_result.get("context_chunks_used", 0)
                            token_usage = task_result.get("token_usage", {})

                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   📝 Response: {response[:200]}...\n"  # Limit display length
                            compiled_details += f"   📚 Chunks used: {chunks_used}\n"
                            if token_usage:
                                compiled_details += f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}\n"

                            # Add task result to CSV data
                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": response,
                                "context_chunks_used": chunks_used,
                                "input_tokens": token_usage.get("input_tokens", 0),
                                "output_tokens": token_usage.get("output_tokens", 0),
                                "total_tokens": token_usage.get("total_tokens", 0)
                            })
                        else:
                            # Handle task-specific errors in display and CSV
                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   ❌ Task Error: {task_result}\n"
                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": f"Task error: {task_result}",
                                "context_chunks_used": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "total_tokens": 0
                            })


            # Create DataFrame for the summary output
            if csv_rows:
                results_df = pd.DataFrame(csv_rows)
                # Ensure column order
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in results_df.columns]
                results_df = results_df[existing_columns]

                # Convert token columns to numeric, handling errors
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                    if token_col in results_df.columns:
                        results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

            else:
                results_df = pd.DataFrame({"Status": ["No task results generated."]})  # Empty DataFrame or error indicator


            return status, compiled_details, results_df # Return DataFrame for display

        else:
            return status, "No task results generated.", pd.DataFrame() # Return empty DataFrame if no task results

    except Exception as e:
        error_msg = f"Workflow execution failed: {e}"
        print(f"❌ {error_msg}")
        return status, error_msg, pd.DataFrame() # Return empty DataFrame on workflow execution error

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document (PDF, JPG, PNG, etc.) to extract details.")

    file_input = gr.File(label="Upload Expense Document", file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff"])
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status", max_lines=5)
    details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
    summary_output = gr.Dataframe(label="Extracted Details (Structured)") # Use Dataframe for structured output

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

# Note: demo.launch() will be called separately when running the Gradio app
# For this subtask, we define the function but don't call launch() yet.

## Run the workflow

### Subtask:
Trigger the RAG workflow with the uploaded file.


**Reasoning**:
The previous execution failed because the required classes (`EnhancedRAGExpenseProcessor`, `TokenUsageTracker`, etc.) were not defined in the current environment. I need to re-execute the cells that define these classes and initialize the `processor` and `token_tracker` instances before the Gradio function can trigger the workflow. The `create_enhanced_rag_workflow` function appears in the notebook history, but I also need to confirm it is accessible in the current session. Finally, I will update `process_document_ui` to invoke the workflow with the uploaded file's path and to format the results for display in the Gradio interface.
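
The result-processing step mentioned above amounts to flattening the nested `task_results` mapping into one row per task. The following is a minimal sketch of that idea, assuming the `{filename: {task_name: result}}` shape carried in the workflow state. `make_row` and `flatten_task_results` are hypothetical helper names used only for illustration, and the sample values are made up.

```python
from typing import Any, Dict, List, Optional


def make_row(filename: str, task: str, response: str,
             chunks_used: int = 0,
             usage: Optional[Dict[str, int]] = None) -> Dict[str, Any]:
    """Build one summary row; missing token counts default to 0."""
    usage = usage or {}
    return {
        "filename": filename,
        "task": task,
        "response": response,
        "context_chunks_used": chunks_used,
        "input_tokens": usage.get("input_tokens", 0),
        "output_tokens": usage.get("output_tokens", 0),
        "total_tokens": usage.get("total_tokens", 0),
    }


def flatten_task_results(task_results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn {filename: {task_name: result}} into a flat list of rows."""
    rows = []
    for filename, doc_results in task_results.items():
        if not isinstance(doc_results, dict):
            # The whole document failed and only an error value was stored.
            rows.append(make_row(filename, "processing_failed",
                                 f"Document processing failed: {doc_results}"))
        elif "error" in doc_results:
            rows.append(make_row(filename, "document_error", doc_results["error"]))
        else:
            for task_name, result in doc_results.items():
                if isinstance(result, dict):
                    rows.append(make_row(
                        filename, task_name,
                        result.get("response", ""),
                        result.get("context_chunks_used", 0),
                        result.get("token_usage"),
                    ))
                else:
                    # A single task failed and stored a plain error value.
                    rows.append(make_row(filename, task_name, f"Task error: {result}"))
    return rows


# Illustrative input only; these values are invented for the example.
sample = {
    "receipt_01": {
        "extract_amount": {
            "response": "42.50 USD",
            "context_chunks_used": 2,
            "token_usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
        },
        "extract_date": "timeout",
    },
    "receipt_02": {"error": "OCR failed"},
}
rows = flatten_task_results(sample)
for row in rows:
    print(row["filename"], row["task"], row["total_tokens"])
```

The returned list of dicts can be passed to `pd.DataFrame(rows)` to feed the `gr.Dataframe` output, which keeps the success path and the partial-results path on the same row-building logic.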



In [None]:
# Re-define or ensure availability of necessary classes and instances
# Assuming the previous cells defining these are re-executed

# Initialize the RAG processor and workflow globally for efficiency in Gradio
# This avoids re-initializing the processor (and thus ChromaDB client/collection)
# on every file upload in the Gradio app.

# Default to None so process_document_ui can detect a failed initialization
# instead of raising NameError on an undefined global.
processor = None
enhanced_rag_workflow = None

# Ensure the TokenUsageTracker is also a persistent instance
try:
    # Re-initialize token_tracker as it might have been reset by the environment
    token_tracker = TokenUsageTracker()
    processor = EnhancedRAGExpenseProcessor()
    # Create the workflow instance
    enhanced_rag_workflow = create_enhanced_rag_workflow(processor)
    print("\n✅ RAG Processor and Workflow initialized globally.")
except Exception as e:
    print(f"\n❌ Error initializing RAG components globally: {e}")
    # process_document_ui checks for a None workflow and reports the failure in the UI

def process_document_ui(file):
    """Function to receive the uploaded file, run the workflow, and display results."""
    # Use the globally initialized workflow and processor
    global enhanced_rag_workflow, processor, token_tracker

    if file is None:
        return "Please upload a file.", "No file uploaded.", pd.DataFrame() # Return empty DataFrame

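    # Depending on the Gradio version, gr.File passes either a path string
    # or a temp-file wrapper whose .name attribute holds the path.
    file_path = file if isinstance(file, str) else file.name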

    status = f"Received file: {file_path}. Starting RAG workflow..."
    print(status) # Print status to console

    # Check if workflow initialization was successful
    if enhanced_rag_workflow is None:
         error_msg = "RAG Workflow failed to initialize. Cannot process file."
         print(f"❌ {error_msg}")
         return error_msg, "Initialization failed.", pd.DataFrame()


    # Execute enhanced workflow with the uploaded file path
    # Reset the token tracker for this specific workflow run if needed,
    # or let it accumulate total usage across all runs.
    # For this demo, let's let it accumulate, as it shows total usage.
    # If per-file usage is needed, reset here: token_tracker = TokenUsageTracker()


    initial_state = {
        "file_paths": [file_path], # Pass the uploaded file path as a list
        "current_file_index": 0, # Not strictly used with single file processing, but keep for state structure
        "processed_filenames": [],
        "current_filename": "", # Not strictly used with single file processing
        "task_results": {},
        "workflow_status": "initialized",
        "error": None
    }

    try:
        print(f"Executing workflow with state: {initial_state}")
        final_state = enhanced_rag_workflow.invoke(initial_state)
        workflow_final_status = final_state.get('workflow_status', 'unknown')
        status = f"Workflow completed with status: {workflow_final_status}"
        print(status) # Print final status
        print(f"Final state: {final_state}") # Print final state for debugging


        # Process final state to display results
        task_results = final_state.get("task_results", {})
        compiled_details = ""
        results_df = pd.DataFrame() # Default to empty DataFrame


        if final_state.get("error"):
             compiled_details = f"Workflow Error: {final_state['error']}\n\n"
             # Attempt to display any partial results if available
             if task_results:
                 compiled_details += "Partial Results:\n"
                 for filename, doc_results in task_results.items():
                      compiled_details += f"\nDocument: {filename}\n"
                      if "error" in doc_results:
                           compiled_details += f"  Error: {doc_results['error']}\n"
                      else:
                          for task_name, task_result in doc_results.items():
                              if isinstance(task_result, dict):
                                   response_preview = task_result.get('response', 'No response')
                                   compiled_details += f"  {task_name}: {response_preview[:100]}...\n"
                              else:
                                  compiled_details += f"  {task_name}: Error - {task_result}\n"

             # If there are partial results that can be put in a DataFrame, try that
             if task_results:
                  try:
                     csv_rows = []
                     for filename, doc_results in task_results.items():
                         if isinstance(doc_results, dict):
                            if "error" in doc_results:
                                csv_rows.append({
                                    "filename": filename,
                                    "task": "workflow_error",
                                    "response": doc_results["error"],
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })
                            else:
                                for task_name, task_result in doc_results.items():
                                    if isinstance(task_result, dict):
                                        token_usage = task_result.get("token_usage", {})
                                        csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": task_result.get("response", ""),
                                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                                            "input_tokens": token_usage.get("input_tokens", 0),
                                            "output_tokens": token_usage.get("output_tokens", 0),
                                            "total_tokens": token_usage.get("total_tokens", 0)
                                        })
                                    else:
                                         csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": f"Task error: {task_result}",
                                            "context_chunks_used": 0,
                                            "input_tokens": 0,
                                            "output_tokens": 0,
                                            "total_tokens": 0
                                        })
                         else:
                              csv_rows.append({
                                    "filename": filename,
                                    "task": "processing_failed",
                                    "response": f"Document processing failed: {doc_results}",
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })

                     if csv_rows:
                        results_df = pd.DataFrame(csv_rows)
                        column_order = [
                            "filename", "task", "response",
                            "context_chunks_used", "input_tokens",
                            "output_tokens", "total_tokens"
                        ]
                        existing_columns = [col for col in column_order if col in results_df.columns]
                        results_df = results_df[existing_columns]
                        for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                            if token_col in results_df.columns:
                                results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

                  except Exception as csv_e:
                       compiled_details += f"\nError compiling partial results DataFrame: {csv_e}"
                       results_df = pd.DataFrame({"Error": [f"Could not compile results: {csv_e}"]}) # Indicate error in DataFrame too


             return status, compiled_details, results_df # Return DataFrame even on error


        # Compile results for display if no workflow-level error
        if task_results:
            compiled_details = ""
            csv_rows = [] # Prepare data for DataFrame display

            for filename, results in task_results.items():
                 compiled_details += f"\n📄 DOCUMENT: {filename}\n" + "─" * 50 + "\n"
                 if "error" in results:
                     compiled_details += f"❌ Error: {results['error']}\n"
                     csv_rows.append({
                        "filename": filename,
                        "task": "document_error",
                        "response": results["error"],
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })
                 else:
                    for task_name, task_result in results.items():
                        if isinstance(task_result, dict):
                            response = task_result.get("response", "No response")
                            chunks_used = task_result.get("context_chunks_used", 0)
                            token_usage = task_result.get("token_usage", {})

                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   📝 Response: {response[:200]}...\n" # Limit display length
                            compiled_details += f"   📚 Chunks used: {chunks_used}\n"
                            if token_usage:
                                compiled_details += f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}\n"

                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": response,
                                "context_chunks_used": chunks_used,
                                "input_tokens": token_usage.get("input_tokens", 0),
                                "output_tokens": token_usage.get("output_tokens", 0),
                                "total_tokens": token_usage.get("total_tokens", 0)
                            })
                        else:
                             compiled_details += f"\n🎯 {task_name.upper()}:\n"
                             compiled_details += f"   ❌ Task Error: {task_result}\n"
                             csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": f"Task error: {task_result}",
                                "context_chunks_used": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "total_tokens": 0
                            })


            # Create DataFrame for the summary output
            if csv_rows:
                results_df = pd.DataFrame(csv_rows)
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in results_df.columns]
                results_df = results_df[existing_columns]
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                     if token_col in results_df.columns:
                        results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

            else:
                 results_df = pd.DataFrame({"Status": ["No task results generated."]})


            # Display token usage summary in the details output as well
            compiled_details += "\n\n📊 TOTAL TOKEN USAGE SUMMARY:\n"
            compiled_details += f"🔢 Total LLM Calls: {token_tracker.call_count}\n"
            compiled_details += f"📥 Total Input Tokens: {token_tracker.total_input_tokens:,}\n"
            compiled_details += f"📤 Total Output Tokens: {token_tracker.total_output_tokens:,}\n"
            compiled_details += f"🎯 Grand Total Tokens: {token_tracker.total_tokens:,}\n"
            if token_tracker.call_count > 0:
                compiled_details += f"📊 Average per call: {token_tracker.total_tokens/token_tracker.call_count:.1f} tokens\n"


            return status, compiled_details, results_df # Return DataFrame for display

        else:
            return status, "No task results generated.", pd.DataFrame() # Return empty DataFrame if no task results

    except Exception as e:
        error_msg = f"Workflow execution failed: {e}"
        print(f"❌ {error_msg}")
        # Attempt to return any partial state/error info
        partial_details = f"An unexpected error occurred during workflow execution: {e}"
        if 'final_state' in locals() and final_state:
             partial_details += f"\n\nPartial workflow state: {final_state.get('workflow_status', 'unknown')}"
             if final_state.get('error'):
                  partial_details += f"\nInternal state error: {final_state['error']}"
             if final_state.get('task_results'):
                  partial_details += "\nPartial task results available (check console for details)."
                  # Optionally add a simplified representation of partial task results to partial_details

        # Attempt to create a simple error DataFrame
        error_df = pd.DataFrame({"Error": [error_msg]})

        return "Workflow Failed", partial_details, error_df


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document (PDF, JPG, PNG, etc.) to extract details.")

    file_input = gr.File(label="Upload Expense Document", file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff"])
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status", max_lines=5)
    details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
    summary_output = gr.Dataframe(label="Extracted Details (Structured)") # Use Dataframe for structured output

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

# Launch the Gradio interface (only when running the script, not needed during subtask execution)
# demo.launch()

**Reasoning**:
The previous execution failed because `TokenUsageTracker` and other classes (`EnhancedRAGExpenseProcessor`, etc.) were not defined. According to the notebook history they were defined in earlier code cells, but the execution environment has lost those definitions. I need to re-execute the cells that define every required component: `TokenUsageTracker`, `EnhancedRAGExpenseProcessor`, the `create_enhanced_rag_workflow` function, and all classes that `EnhancedRAGExpenseProcessor` depends on. Once they are defined, I will initialize the global processor and workflow instances again, define the `process_document_ui` function that uses them, and launch the Gradio app.
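
One way to make this kind of failure easier to diagnose is to check which required names are missing before attempting initialization. The snippet below is a minimal sketch under that assumption; `REQUIRED_NAMES` and `missing_definitions` are hypothetical names introduced here, and the list only covers the components named in the paragraph above.

```python
from typing import Iterable, List, Mapping

# Components named in the reasoning above; extend the list as needed.
REQUIRED_NAMES = [
    "TokenUsageTracker",
    "EnhancedRAGExpenseProcessor",
    "create_enhanced_rag_workflow",
]


def missing_definitions(namespace: Mapping[str, object],
                        required: Iterable[str] = REQUIRED_NAMES) -> List[str]:
    """Return the required names that are not defined in the given namespace."""
    return [name for name in required if name not in namespace]


# In a notebook, globals() reflects everything defined by executed cells.
missing = missing_definitions(globals())
if missing:
    print(f"Re-run the cells that define: {', '.join(missing)}")
else:
    print("All required components are defined.")
```

Running this at the top of the initialization cell reports the missing definitions by name rather than surfacing them as a `NameError` partway through setup.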



In [None]:
# Re-execute code from previous cells to ensure all necessary components are defined.
# This includes TokenUsageTracker, ClaimDocument, FilenameBasedDocumentManager,
# EnhancedOCRProcessor, EnhancedIsolatedVectorStore, EnhancedExpenseTaskManager,
# EnhancedRAGExpenseProcessor, and create_enhanced_rag_workflow.

# ================================
# STEP 2: TOKEN USAGE TRACKING (Re-definition)
# ================================

class TokenUsageTracker:
    """Track token usage across all LLM calls"""

    def __init__(self):
        self.call_history = []
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_tokens = 0
        self.call_count = 0

    def track_call(self, operation: str, filename: str, task: str, response):
        """Track a single LLM call and extract usage info"""

        usage_info = {
            "operation": operation,
            "filename": filename,
            "task": task,
            "timestamp": datetime.now().isoformat(),
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "duration_ms": 0
        }

        # Extract token usage from response
        try:
            if hasattr(response, 'usage_metadata') and response.usage_metadata:
                usage_info["input_tokens"] = response.usage_metadata.get('input_tokens', 0)
                usage_info["output_tokens"] = response.usage_metadata.get('output_tokens', 0)
                usage_info["total_tokens"] = response.usage_metadata.get('total_tokens', 0)

            # Fallback: try response_metadata
            elif hasattr(response, 'response_metadata') and response.response_metadata:
                metadata = response.response_metadata
                usage_info["input_tokens"] = metadata.get('prompt_eval_count', 0)
                usage_info["output_tokens"] = metadata.get('eval_count', 0)
                usage_info["total_tokens"] = usage_info["input_tokens"] + usage_info["output_tokens"]
                usage_info["duration_ms"] = metadata.get('total_duration', 0) // 1000000  # Convert to ms

        except Exception as e:
            print(f"⚠️ Could not extract token usage: {e}")

        # Update totals
        self.total_input_tokens += usage_info["input_tokens"]
        self.total_output_tokens += usage_info["output_tokens"]
        self.total_tokens += usage_info["total_tokens"]
        self.call_count += 1

        # Store call history
        self.call_history.append(usage_info)

        # Print usage info
        self.print_usage_info(usage_info)

        return usage_info


    def print_usage_info(self, usage_info: Dict[str, Any]):
        """Print formatted usage information"""
        print(f"📊 TOKEN USAGE - {usage_info['operation']} | {usage_info['filename']} | {usage_info['task']}")
        print(f"   📥 Input: {usage_info['input_tokens']} tokens")
        print(f"   📤 Output: {usage_info['output_tokens']} tokens")
        print(f"   🔢 Total: {usage_info['total_tokens']} tokens")
        if usage_info['duration_ms'] > 0:
            print(f"   ⏱️ Duration: {usage_info['duration_ms']}ms")
        print()

    def print_summary(self):
        """Print overall token usage summary"""
        print("=" * 60)
        print("📊 TOTAL TOKEN USAGE SUMMARY")
        print("=" * 60)
        print(f"🔢 Total LLM Calls: {self.call_count}")
        print(f"📥 Total Input Tokens: {self.total_input_tokens:,}")
        print(f"📤 Total Output Tokens: {self.total_output_tokens:,}")
        print(f"🎯 Grand Total Tokens: {self.total_tokens:,}")

        if self.call_count > 0:
            print(f"📊 Average per call: {self.total_tokens/self.call_count:.1f} tokens")
        print()

# Global token tracker
token_tracker = TokenUsageTracker()

print("\n✅ SETUP COMPLETE - TokenUsageTracker!")
print("=" * 70)

# ================================
# STEP 3: FILENAME-BASED DOCUMENT MANAGEMENT (Re-definition)
# ================================

@dataclass
class ClaimDocument:
    """Document with filename-based identification"""
    filename: str  # Primary identifier (no more UUIDs!)
    file_path: str
    raw_text: str
    chunks: List[str]
    metadata: Dict[str, Any]
    processed_timestamp: datetime

class FilenameBasedDocumentManager:
    """Manages documents using filenames as primary identifiers"""

    def __init__(self):
        self.documents_registry = {}  # filename -> ClaimDocument
        self.chunk_to_file_map = {}  # chunk_id -> filename

    def register_document(self, file_path: str, raw_text: str) -> str:
        """Register document using filename as ID"""

        filename = Path(file_path).stem  # Get filename without extension

        print(f"📋 REGISTERING DOCUMENT: {filename}")
        print(f"   📁 Source: {Path(file_path).name}")
        print(f"   📄 Text length: {len(raw_text)} characters")

        # Create isolated chunks for this document
        chunks = self.create_document_chunks(raw_text, filename)

        claim_doc = ClaimDocument(
            filename=filename,
            file_path=file_path,
            raw_text=raw_text,
            chunks=chunks,
            metadata={
                "file_name": Path(file_path).name,
                "file_extension": Path(file_path).suffix,
                "chunk_count": len(chunks),
                "source": "ocr_extraction"
            },
            processed_timestamp=datetime.now()
        )

        self.documents_registry[filename] = claim_doc

        # Update chunk mapping
        for i, chunk in enumerate(chunks):
            chunk_id = f"{filename}_chunk_{i}"
            self.chunk_to_file_map[chunk_id] = filename

        print(f"✅ Document registered: {filename} with {len(chunks)} chunks")
        return filename

    def create_document_chunks(self, text: str, filename: str) -> List[str]:
        """Create chunks with filename-specific context isolation"""

        print(f"🔪 CHUNKING DOCUMENT: {filename}")

        lines = text.split('\n')
        chunks = []
        current_chunk = []
        current_length = 0
        max_chunk_size = 500

        # Expense document section markers
        section_markers = [
            'total', 'amount', 'date', 'vendor', 'receipt', 'invoice',
            'item', 'quantity', 'price', 'tax', 'subtotal'
        ]

        for line in lines:
            line = line.strip()
            if not line:
                continue

            line_length = len(line)
            is_section_start = any(marker in line.lower() for marker in section_markers)

            if (current_length + line_length > max_chunk_size) or \
               (is_section_start and current_chunk and current_length > 200):

                chunk_text = '\n'.join(current_chunk)
                if chunk_text.strip():
                    # Add filename isolation metadata to chunk
                    isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                    chunks.append(isolated_chunk)

                current_chunk = [line]
                current_length = line_length
            else:
                current_chunk.append(line)
                current_length += line_length + 1

        # Add final chunk
        if current_chunk:
            chunk_text = '\n'.join(current_chunk)
            if chunk_text.strip():
                isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                chunks.append(isolated_chunk)

        print(f"   🔪 Created {len(chunks)} chunks (avg {len(text)//len(chunks) if chunks else 0} chars each)")
        return chunks



    def get_document_context(self, filename: str) -> Optional[ClaimDocument]:
        """Get complete context for a specific document"""
        return self.documents_registry.get(filename)

    def list_all_documents(self) -> List[str]:
        """List all registered filenames"""
        return list(self.documents_registry.keys())

print("\n✅ SETUP COMPLETE - FILENAME-BASED DOCUMENT MANAGEMENT!")
print("=" * 70)

# ================================
# STEP 4: ENHANCED OCR PROCESSOR (Re-definition)
# ================================

class EnhancedOCRProcessor:
    """OCR processing with detailed progress tracking"""

    def __init__(self):
        self.supported_formats = ['.pdf', '.jpg', '.jpeg', '.png', '.tiff']

    def extract_text_from_document(self, file_path: str) -> str:
        """Extract text with detailed progress tracking"""

        filename = Path(file_path).name
        print(f"🔍 EXTRACTING TEXT FROM: {filename}")
        print(f"   📁 Full path: {file_path}")
        print(f"   📊 File size: {Path(file_path).stat().st_size / 1024:.1f} KB")

        try:
            from unstructured.partition.auto import partition

            print(f"   🔄 Processing with UnstructuredIO...")

            # Process document with UnstructuredIO
            elements = partition(filename=file_path)

            print(f"   📋 Found {len(elements)} document elements")

            # Extract text from all elements
            full_text = ""
            for i, element in enumerate(elements):
                if hasattr(element, 'text') and element.text:
                    full_text += element.text + "\n"
                    if i < 5:  # Show first few elements
                        print(f"     Element {i+1}: {element.text[:50]}...")

            # Clean and normalize text
            full_text = self.clean_extracted_text(full_text)

            print(f"   ✅ Extracted {len(full_text)} characters")
            print(f"   📝 Text preview: {full_text[:100]}...")
            return full_text

        except Exception as e:
            print(f"   ❌ OCR extraction failed: {e}")
            return ""

    def clean_extracted_text(self, text: str) -> str:
        """Clean extracted text with progress info"""
        if not text:
            return ""

        original_length = len(text)
        lines = text.split('\n')
        cleaned_lines = []

        for line in lines:
            line = line.strip()
            if line and len(line) > 2:
                cleaned_lines.append(line)

        cleaned_text = '\n'.join(cleaned_lines)
        print(f"   🧹 Cleaned: {original_length} → {len(cleaned_text)} chars ({len(cleaned_lines)} lines)")

        return cleaned_text

print("\n✅ SETUP COMPLETE - ENHANCED OCR PROCESSOR!")
print("=" * 70)

# ================================
# STEP 5: ENHANCED VECTOR STORE (Re-definition)
# ================================

class EnhancedIsolatedVectorStore:
    """ChromaDB with enhanced tracking and filename-based isolation"""

    def __init__(self, embedding_model: str = "nomic-embed-text"): # Use default or passed model
        import chromadb

        self.embedding_model = embedding_model

        print(f"🗄️ INITIALIZING VECTOR STORE")
        print(f"   🤖 Embedding Model: {embedding_model}")

        # Initialize ChromaDB client using new API
        self.client = chromadb.PersistentClient(path="./chroma_db")

        # Create collection
        self.collection = self.client.get_or_create_collection(
            name="filename_based_expense_claims",
            metadata={"hnsw:space": "cosine"}
        )

        print(f"   ✅ ChromaDB initialized")

    def embed_text(self, text: str, filename: str = "unknown") -> List[float]:
        """Generate embeddings with progress tracking"""

        print(f"🔢 GENERATING EMBEDDING: {filename}")
        print(f"   📝 Text length: {len(text)} chars")

        try:
            # Ensure ollama is imported and available
            import ollama
            response = ollama.embeddings(model=self.embedding_model, prompt=text)
            embedding = response['embedding']
            print(f"   ✅ Generated {len(embedding)}-dimensional embedding")
            return embedding
        except Exception as e:
            print(f"   ❌ Embedding error: {e}")
            return []

    def add_document_chunks(self, filename: str, chunks: List[str], metadata: Dict[str, Any]):
        """Add chunks for a specific document with detailed tracking"""

        print(f"📚 ADDING CHUNKS TO VECTOR STORE: {filename}")
        print(f"   📊 Number of chunks: {len(chunks)}")

        embeddings = []
        chunk_ids = []
        metadatas = []
        documents = []  # Only chunks that were embedded, kept aligned with embeddings

        for i, chunk in enumerate(chunks):
            print(f"   🔄 Processing chunk {i+1}/{len(chunks)}")

            # Generate embedding
            embedding = self.embed_text(chunk, f"{filename}_chunk_{i}")
            if not embedding:
                print(f"   ⚠️ Skipping chunk {i+1} - no embedding generated")
                continue

            chunk_id = f"{filename}_chunk_{i}"
            chunk_metadata = {
                **metadata,
                "filename": filename,
                "chunk_index": i,
                "chunk_id": chunk_id,
                "isolated": True
            }

            embeddings.append(embedding)
            chunk_ids.append(chunk_id)
            metadatas.append(chunk_metadata)
            documents.append(chunk)

        # Add to ChromaDB; all four lists must have the same length,
        # so pass the filtered documents rather than the full chunks list
        if embeddings:
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids
            )

            print(f"   ✅ Added {len(embeddings)} chunks to vector store")
        else:
            print(f"   ❌ No chunks added - all embeddings failed")

    def query_document_specific(self, query: str, filename: str, n_results: int = 3) -> Dict[str, Any]:
        """Query specific document only - prevents cross-contamination"""

        print(f"🔍 QUERYING VECTOR STORE: {filename}")
        print(f"   ❓ Query: {query}")
        print(f"   📊 Requesting {n_results} results")

        query_embedding = self.embed_text(query, f"query_{filename}")
        if not query_embedding:
            return {"error": "Failed to generate query embedding"}

        # Query with filename filter to ensure isolation
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where={"filename": filename},  # CRITICAL: Isolates to specific document
            include=["documents", "metadatas", "distances"]
        )

        print(f"   ✅ Found {len(results['documents'][0]) if results['documents'] else 0} relevant chunks")

        return {
            "documents": results['documents'][0] if results['documents'] else [],
            "metadatas": results['metadatas'][0] if results['metadatas'] else [],
            "distances": results['distances'][0] if results['distances'] else [],
            "filename": filename
        }

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about stored documents"""

        print("📊 GENERATING COLLECTION STATISTICS")

        count = self.collection.count()

        # Get unique filenames
        all_metadata = self.collection.get(include=["metadatas"])
        filenames = set()
        if all_metadata and all_metadata['metadatas']:
            for meta in all_metadata['metadatas']:
                # Skip entries stored without metadata (None) as well as those lacking a filename
                if meta and 'filename' in meta:
                    filenames.add(meta['filename'])

        stats = {
            "total_chunks": count,
            "unique_documents": len(filenames),
            "filenames": sorted(filenames)  # Sorted so the output order is stable between runs
        }

        print(f"   📚 Total chunks: {stats['total_chunks']}")
        print(f"   📄 Unique documents: {stats['unique_documents']}")
        print(f"   📝 Documents: {', '.join(stats['filenames'])}")

        return stats

print("\n✅ SETUP COMPLETE - ENHANCED VECTOR STORE!")
print("=" * 70)

# ================================
# STEP 6: ENHANCED EXPENSE TASK MANAGER (Re-definition)
# ================================

class EnhancedExpenseTaskManager:
    """Manages predefined expense extraction tasks with better tracking"""

    def __init__(self):
        self.predefined_tasks = {
            "extract_amount": {
                "query": "total amount due payment cost price sum money dollar",
                "description": "Extract the total amount from this expense document",
                "expected_format": "numeric value with currency"
            },
            "extract_date": {
                "query": "date transaction purchase invoice receipt timestamp when",
                "description": "Extract the date from this expense document",
                "expected_format": "date in YYYY-MM-DD format"
            },
            "extract_vendor": {
                "query": "vendor merchant company business supplier store restaurant hotel",
                "description": "Extract vendor/merchant name from this expense document",
                "expected_format": "company or business name"
            },
            "extract_category": {
                "query": "category type classification expense kind service product item",
                "description": "Determine expense category from this document",
                "expected_format": "expense category classification"
            },
            "extract_items": {
                "query": "items products services line items purchases description details",
                "description": "Extract itemized details from this expense document",
                "expected_format": "list of items or services"
            },
            "extract_tax": {
                "query": "tax VAT GST sales tax tax rate percentage",
                "description": "Extract tax information from this expense document",
                "expected_format": "tax amount and rate"
            }
        }

    def get_task_info(self, task_name: str) -> Dict[str, str]:
        """Get complete task information"""
        return self.predefined_tasks.get(task_name, {})

    def list_available_tasks(self) -> List[str]:
        """List all available extraction tasks"""
        return list(self.predefined_tasks.keys())

print("\n✅ SETUP COMPLETE - ENHANCED EXPENSE TASK MANAGER!")
print("=" * 70)


# ================================
# STEP 7: ENHANCED RAG PROCESSOR (Re-definition)
# ================================

class EnhancedRAGExpenseProcessor:
    """RAG-based expense processor with comprehensive tracking"""

    def __init__(self, text_model: str = "gemma3:1b"): # Use default or passed model
        from langchain_ollama import ChatOllama

        print(f"🚀 INITIALIZING RAG EXPENSE PROCESSOR")
        print(f"   🤖 Text Model: {text_model}")

        self.llm = ChatOllama(
            model=text_model,
            temperature=0.1,
            base_url="http://127.0.0.1:11434"
        )

        self.vector_store = EnhancedIsolatedVectorStore()
        self.task_manager = EnhancedExpenseTaskManager()
        self.document_manager = FilenameBasedDocumentManager()
        self.ocr_processor = EnhancedOCRProcessor()

        print("   ✅ All components initialized")

    def ingest_document(self, file_path: str) -> Optional[str]:
        """INGESTION PHASE: Process document and store in vector DB.

        Returns the filename-based document ID, or None if no text could be extracted.
        """

        filename = Path(file_path).name
        print("\n" + "="*70)
        print(f"🔄 INGESTION PHASE STARTING")
        print(f"📄 FILE: {filename}")
        print(f"📁 PATH: {file_path}")
        print("="*70)

        # Step 1: OCR extraction
        raw_text = self.ocr_processor.extract_text_from_document(file_path)
        if not raw_text:
            print("❌ INGESTION FAILED: No text extracted")
            return None

        # Step 2: Register document with filename-based system
        filename_id = self.document_manager.register_document(file_path, raw_text)

        # Step 3: Get document context
        document = self.document_manager.get_document_context(filename_id)

        # Step 4: Store in vector database
        metadata = {
            **document.metadata,
            "ingestion_timestamp": datetime.now().isoformat()
        }

        self.vector_store.add_document_chunks(
            filename=filename_id,
            chunks=document.chunks,
            metadata=metadata
        )

        print(f"✅ INGESTION COMPLETED: {filename_id}")
        print("="*70)
        return filename_id

    def process_expense_task(self, filename: str, task_name: str) -> Dict[str, Any]:
        """INFERENCE PHASE: Process specific task for document"""

        print(f"\n🎯 INFERENCE PHASE STARTING")
        print(f"📄 DOCUMENT: {filename}")
        print(f"🎯 TASK: {task_name}")
        print("-" * 50)

        # Step 1: Get task information
        task_info = self.task_manager.get_task_info(task_name)
        if not task_info:
            return {"error": f"Unknown task: {task_name}"}

        task_query = task_info.get("query", "")
        task_description = task_info.get("description", "")

        print(f"📋 Task Description: {task_description}")
        print(f"🔍 Search Query: {task_query}")

        # Step 2: Retrieve relevant chunks (ISOLATED to this document)
        retrieval_results = self.vector_store.query_document_specific(
            query=task_query,
            filename=filename,
            n_results=3
        )

        if retrieval_results.get("error"):
            return retrieval_results

        # Step 3: Prepare optimized context
        context = self.optimize_context(retrieval_results, task_name)

        # Step 4: Generate response with LLM (WITH TOKEN TRACKING)
        response_text, token_usage = self.generate_task_response_with_tracking(
            context, task_name, task_description, filename
        )

        result = {
            "task": task_name,
            "filename": filename,
            "response": response_text,
            "context_chunks_used": len(retrieval_results["documents"]),
            "token_usage": token_usage
        }

        print(f"✅ INFERENCE COMPLETED: {task_name} for {filename}")
        print("-" * 50)

        return result

    def optimize_context(self, retrieval_results: Dict[str, Any], task_name: str) -> str:
        """CONTEXT OPTIMIZATION: Reduce context overloading"""

        documents = retrieval_results.get("documents", [])
        distances = retrieval_results.get("distances", [])
        filename = retrieval_results.get("filename", "unknown")

        print(f"🔧 OPTIMIZING CONTEXT: {filename}")
        print(f"   📊 Raw chunks: {len(documents)}")

        if not documents:
            return "No relevant context found"

        # Rank chunks by ascending distance (a smaller distance means a closer match)
        doc_scores = list(zip(documents, distances))
        doc_scores.sort(key=lambda x: x[1])

        optimized_chunks = []
        total_length = 0
        max_context_length = 1500

        for i, (doc, score) in enumerate(doc_scores):
            # Remove document prefix from chunks
            clean_doc = doc.replace(f"[DOCUMENT: {filename}]\n", "")

            if total_length + len(clean_doc) <= max_context_length:
                optimized_chunks.append(clean_doc)
                total_length += len(clean_doc)
                print(f"   ✅ Chunk {i+1}: {len(clean_doc)} chars (distance: {score:.3f})")
            else:
                remaining_space = max_context_length - total_length
                if remaining_space > 100:
                    truncated = clean_doc[:remaining_space] + "..."
                    optimized_chunks.append(truncated)
                    print(f"   ✂️ Chunk {i+1}: truncated to {len(truncated)} chars")
                break

        context = "\n\n---\n\n".join(optimized_chunks)
        print(f"   🎯 Final context: {len(context)} chars from {len(optimized_chunks)} chunks")

        return context

    def generate_task_response_with_tracking(self, context: str, task_name: str, task_description: str, filename: str) -> Tuple[str, Dict[str, Any]]:
        """Generate LLM response with token usage tracking"""

        print(f"🤖 GENERATING LLM RESPONSE: {task_name} | {filename}")

        prompt = f"""You are an expert expense analyst. {task_description}

CONTEXT FROM EXPENSE DOCUMENT ({filename}):
{context}

TASK: {task_name}
INSTRUCTION: {task_description}

Based ONLY on the context provided above, extract the requested information. Be precise and factual. If the information is not clearly present in the context, state "Information not found in provided context."

Response:"""

        print(f"   📝 Prompt length: {len(prompt)} characters")

        try:
            response = self.llm.invoke(prompt)

            # Track token usage
            token_usage = token_tracker.track_call("llm_inference", filename, task_name, response)

            return response.content.strip(), token_usage

        except Exception as e:
            error_msg = f"Error generating response: {e}"
            print(f"   ❌ {error_msg}")
            return error_msg, {}

    def process_all_tasks_for_document(self, filename: str) -> Dict[str, Any]:
        """Process all predefined tasks for a document"""

        print(f"\n📊 PROCESSING ALL TASKS FOR: {filename}")
        print("="*50)

        tasks = self.task_manager.list_available_tasks()
        results = {}

        for i, task in enumerate(tasks, 1):
            print(f"\n[{i}/{len(tasks)}] Starting task: {task}")
            result = self.process_expense_task(filename, task)
            results[task] = result

        print(f"\n✅ ALL TASKS COMPLETED FOR: {filename}")
        return results

print("\n✅ SETUP COMPLETE - ENHANCED RAG PROCESSOR!")
print("=" * 70)


# ================================
# STEP 8: ENHANCED WORKFLOW (Re-definition)
# ================================

from langgraph.graph import StateGraph
from typing import TypedDict

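class EnhancedRAGWorkflowState(TypedDict):
    """Enhanced state for RAG workflow"""
    file_paths: List[str]
    current_file_index: int
    # Filename IDs (str) for successful ingestions; error dicts for failed ones
    processed_filenames: List[Any]
    current_filename: str
    task_results: Dict[str, Dict[str, Any]]
    workflow_status: str
    error: Optional[str]
    # Declared because process_document_ui passes it and the compilation node reads it
    initial_state: Dict[str, Any]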

def create_enhanced_rag_workflow(processor: EnhancedRAGExpenseProcessor) -> Any:
    """Create the enhanced LangGraph workflow and return it compiled (ready for .invoke())"""

    def enhanced_ingestion_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced ingestion with detailed tracking"""

        print("\n" + "🔄 WORKFLOW: ENHANCED INGESTION PHASE STARTING")
        print("="*70)

        file_paths = state.get("file_paths", [])
        processed_filenames = []

        if not file_paths:
            state["workflow_status"] = "ingestion_failed"
            state["error"] = "No file paths provided for ingestion."
            print("❌ INGESTION FAILED: No file paths provided.")
            return state


        for i, file_path in enumerate(file_paths, 1):
            print(f"\n[{i}/{len(file_paths)}] Processing file: {Path(file_path).name}")

            try:
                # Use the processor to ingest the document
                filename = processor.ingest_document(file_path)
                if filename:
                    processed_filenames.append(filename)
                    print(f"✅ Successfully ingested: {filename}")
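                else:
                    print(f"❌ Failed to ingest: {Path(file_path).name}")
                    # Record the failure so it also appears in the CSV and summary outputs
                    processed_filenames.append({"error": "No text extracted from document", "filename": Path(file_path).name})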

            except Exception as e:
                print(f"❌ Error ingesting {Path(file_path).name}: {e}")
                state["error"] = str(e) # Store the error in state
                # Decide if you want to stop on first ingestion error or continue
                # For now, let's continue to process other files if possible
                processed_filenames.append({"error": str(e), "filename": Path(file_path).name})


        state["processed_filenames"] = processed_filenames
        state["workflow_status"] = "ingestion_complete" if any(isinstance(f, str) for f in processed_filenames) else "ingestion_failed" # Check if at least one file was successfully processed

        print(f"\n📊 INGESTION PHASE COMPLETED")
        successful_count = sum(1 for f in processed_filenames if isinstance(f, str))
        print(f"   ✅ Successfully processed: {successful_count} files")
        print(f"   ❌ Failed: {len(file_paths) - successful_count} files")


        return state

    def enhanced_task_processing_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced task processing with detailed tracking"""

        print("\n🎯 WORKFLOW: ENHANCED TASK PROCESSING PHASE STARTING")
        print("="*70)

        # Filter out ingestion errors before processing tasks
        processable_filenames = [f for f in state.get("processed_filenames", []) if isinstance(f, str)]
        task_results = state.get("task_results", {}) # Initialize or get existing results

        if not processable_filenames:
            state["workflow_status"] = "processing_skipped"
            print("⚠️ TASK PROCESSING SKIPPED: No documents successfully ingested.")
            return state


        for i, filename in enumerate(processable_filenames, 1):
            print(f"\n[{i}/{len(processable_filenames)}] Processing tasks for: {filename}")

            try:
                results = processor.process_all_tasks_for_document(filename)
                task_results[filename] = results
                print(f"✅ Completed all tasks for: {filename}")

            except Exception as e:
                print(f"❌ Error processing tasks for {filename}: {e}")
                task_results[filename] = {"error": str(e)}
                state["error"] = str(e) # Store the error in state


        state["task_results"] = task_results
        state["workflow_status"] = "processing_complete"

        print(f"\n📊 TASK PROCESSING PHASE COMPLETED")
        print(f"   📄 Documents processed: {len(task_results)}")

        return state

    def enhanced_results_compilation_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced results compilation with detailed stats and CSV export"""

        print("\n📊 WORKFLOW: ENHANCED RESULTS COMPILATION STARTING")
        print("="*70)

        task_results = state.get("task_results", {})
        # token_tracker is a global instance assumed to be available
        # processor instance (and its vector_store) is also assumed to be available

        if not task_results:
            state["workflow_status"] = "compilation_skipped"
            print("⚠️ RESULTS COMPILATION SKIPPED: No task results to compile.")
            token_tracker.print_summary() # Print summary even if compilation skipped
            return state


        # Compile detailed statistics
        total_documents_attempted_ingestion = len(state.get("file_paths", [])) # Count original files
        successfully_ingested_filenames = [f for f in state.get("processed_filenames", []) if isinstance(f, str)]
        total_documents_successfully_processed = len(task_results) # Count documents with task results (successful or not)


        # Save results with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"enhanced_rag_expense_results_{timestamp}.json"
        csv_file = f"enhanced_rag_expense_results_{timestamp}.csv"

        # Create comprehensive results package
        comprehensive_results = {
            "timestamp": timestamp,
            "summary": {
                "total_documents_attempted_ingestion": total_documents_attempted_ingestion,
                "successfully_ingested_documents": len(successfully_ingested_filenames),
                "documents_with_task_results": total_documents_successfully_processed,
                "failed_ingestion": total_documents_attempted_ingestion - len(successfully_ingested_filenames)

            },
            "token_usage_summary": {
                "total_calls": token_tracker.call_count,
                "total_input_tokens": token_tracker.total_input_tokens,
                "total_output_tokens": token_tracker.total_output_tokens, # Fixed typo here
                "total_tokens": token_tracker.total_tokens
            },
            "document_results": task_results,
            "token_call_history": token_tracker.call_history,
            "initial_state": state.get("initial_state", {}) # Include initial state for debugging
        }

        # Save JSON results
        try:
            with open(results_file, 'w') as f:
                json.dump(comprehensive_results, f, indent=2, default=str)
            print(f"💾 JSON RESULTS SAVED TO: {results_file}")
        except Exception as e:
            print(f"❌ Error saving JSON results: {e}")
            state["error"] = f"Error saving JSON results: {e}"


        # Create CSV from results
        csv_rows = []

        # Add rows for documents that failed ingestion
        failed_ingestion_info = [f for f in state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
        for fail_info in failed_ingestion_info:
             csv_rows.append({
                "filename": fail_info.get("filename", "unknown"),
                "task": "ingestion_failed",
                "response": fail_info.get("error", "Unknown ingestion error"),
                "context_chunks_used": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0
            })


        for filename, doc_results in task_results.items():
            if "error" in doc_results:
                # Add error row for document that failed task processing after successful ingestion
                csv_rows.append({
                    "filename": filename,
                    "task": "processing_failed",
                    "response": doc_results["error"],
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
            else:
                # Process each task for this document
                for task_name, task_result in doc_results.items():
                    if isinstance(task_result, dict):
                        token_usage = task_result.get("token_usage", {})
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            # Fall back to the task's error message when no response was produced
                            "response": task_result.get("response") or task_result.get("error", ""),
                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                            "input_tokens": token_usage.get("input_tokens", 0),
                            "output_tokens": token_usage.get("output_tokens", 0),
                            "total_tokens": token_usage.get("total_tokens", 0)
                        })
                    else:
                         # Handle task-specific errors (if task_result is not a dict but an error string)
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": f"Task error: {task_result}",
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })


        # Save CSV
        df = pd.DataFrame(csv_rows) # Create DataFrame even if empty

        if not df.empty: # Check if DataFrame is not empty before processing
            try:
                # Reorder columns for better readability - handle missing columns gracefully
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in df.columns]
                df = df[existing_columns]


                # Convert token columns to numeric, handling errors
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                     if token_col in df.columns:
                        df[token_col] = pd.to_numeric(df[token_col], errors='coerce').fillna(0).astype(int)


                # Save to CSV
                df.to_csv(csv_file, index=False, encoding='utf-8')
                print(f"💾 CSV RESULTS SAVED TO: {csv_file}")

                # Display summary statistics from CSV
                print(f"\n📊 CSV Summary:")
                print(f"   📄 Total rows: {len(df)}")
                # Ensure 'filename' column exists before calling nunique
                if 'filename' in df.columns:
                    print(f"   📁 Documents: {df['filename'].nunique()}")
                     # Handle case where no tasks were processed successfully
                    successful_task_rows = df[~df['task'].isin(['ingestion_failed', 'processing_failed', 'workflow_error'])]
                    if not successful_task_rows.empty and 'filename' in successful_task_rows.columns:
                         print(f"   🎯 Avg tasks per successful doc: {successful_task_rows.groupby('filename').size().mean():.1f}")
                    else:
                        print("   🎯 Avg tasks per successful doc: N/A (No successful tasks)")

                # Ensure 'total_tokens' column exists and is numeric before summing
                if 'total_tokens' in df.columns:
                    try:
                        total_tokens_sum = df['total_tokens'].sum()
                        print(f"   🔢 Total tokens used: {total_tokens_sum:,}")
                        # Update the in-memory results only; the JSON file above has already been written
                        comprehensive_results['token_usage_summary']['total_tokens_from_csv'] = int(total_tokens_sum)

                    except Exception as e:
                         print(f"⚠️ Could not calculate total tokens from CSV: {e}")
                else:
                    print("⚠️ 'total_tokens' column not found in CSV.")

            except Exception as e:
                print(f"❌ Error processing or saving CSV results: {e}")
                state["error"] = f"Error processing or saving CSV results: {e}"

        else:
             print("⚠️ No CSV rows generated. Skipping CSV save.")


        # Also save a summary CSV with aggregated data per document
        summary_csv_file = f"enhanced_rag_expense_summary_{timestamp}.csv"
        summary_rows = []

        # Include documents that failed ingestion in summary
        for fail_info in failed_ingestion_info:
            summary_rows.append({
                "filename": fail_info.get("filename", "unknown"),
                "status": "Ingestion Failed",
                "error_message": fail_info.get("error", "Unknown error")
            })


        for filename, doc_results in task_results.items():
            if "error" not in doc_results:
                row = {"filename": filename, "status": "Processed"}

                # Extract key information from each task
                for task_name in ["extract_amount", "extract_date", "extract_vendor",
                                "extract_category", "extract_items", "extract_tax"]:
                    if task_name in doc_results and isinstance(doc_results[task_name], dict):
                        response = doc_results[task_name].get("response") or doc_results[task_name].get("error", "")
                        # Clean the response (take first line or first 100 chars)
                        cleaned = response.split('\n')[0][:100] if response else ""
                        row[task_name] = cleaned
                    elif task_name in doc_results:
                         # Handle task-specific errors
                         row[task_name] = f"Error: {doc_results[task_name]}"
                    else:
                        row[task_name] = "Task Not Run"


                # Add token totals
                total_tokens = sum(
                    doc_results.get(task, {}).get("token_usage", {}).get("total_tokens", 0)
                    for task in doc_results if isinstance(doc_results.get(task), dict)
                )
                row["total_tokens_used"] = total_tokens

                summary_rows.append(row)
            else:
                 # Add document that failed task processing after ingestion
                 summary_rows.append({
                     "filename": filename,
                     "status": "Task Processing Failed",
                     "error_message": doc_results["error"]
                 })


        if summary_rows:
            try:
                summary_df = pd.DataFrame(summary_rows)
                # Place 'filename', 'status' and 'error_message' first; reindex creates any that are missing
                summary_column_order = ["filename", "status", "error_message"] + [col for col in summary_df.columns if col not in ["filename", "status", "error_message"]]
                summary_df = summary_df.reindex(columns=summary_column_order)

                summary_df.to_csv(summary_csv_file, index=False, encoding='utf-8')
                print(f"💾 SUMMARY CSV SAVED TO: {summary_csv_file}")
            except Exception as e:
                print(f"❌ Error saving summary CSV: {e}")
                state["error"] = f"Error saving summary CSV: {e}"


        print(f"\n📊 FILES SAVED:")
        if os.path.exists(results_file):
            print(f"   📄 Detailed JSON: {results_file}")
        if os.path.exists(csv_file) and not df.empty:
            print(f"   📄 Detailed CSV: {csv_file}")
        if os.path.exists(summary_csv_file) and summary_rows:
            print(f"   📄 Summary CSV: {summary_csv_file}")
        # Count only documents whose task processing finished without a document-level error
        fully_processed = sum(1 for r in task_results.values() if isinstance(r, dict) and "error" not in r)
        print(f"   📄 Documents processed successfully (ingestion+tasks): {fully_processed}/{total_documents_attempted_ingestion}")

        # Print token usage summary
        token_tracker.print_summary()

        state["workflow_status"] = "complete"
        return state


    # Build enhanced workflow
    workflow = StateGraph(EnhancedRAGWorkflowState)

    workflow.add_node("enhanced_ingestion", enhanced_ingestion_node)
    workflow.add_node("enhanced_task_processing", enhanced_task_processing_node)
    workflow.add_node("enhanced_results_compilation", enhanced_results_compilation_node)

    workflow.set_entry_point("enhanced_ingestion")
    workflow.add_edge("enhanced_ingestion", "enhanced_task_processing")
    workflow.add_edge("enhanced_task_processing", "enhanced_results_compilation")
    workflow.set_finish_point("enhanced_results_compilation")

    return workflow.compile()

print("\n✅ SETUP COMPLETE - ENHANCED WORKFLOW!")
print("=" * 70)


# Initialize the RAG processor and workflow globally for efficiency in Gradio
# This avoids re-initializing the processor (and thus ChromaDB client/collection)
# on every file upload in the Gradio app.

# Ensure the TokenUsageTracker is also a persistent instance
try:
    # Re-initialize token_tracker to clear previous runs' data if desired,
    # or let it accumulate total usage across all runs.
    # For this demo, let's re-initialize for per-file/per-run tracking clarity in UI.
    token_tracker = TokenUsageTracker()
    processor = EnhancedRAGExpenseProcessor()
    # Create the workflow instance
    enhanced_rag_workflow = create_enhanced_rag_workflow(processor)
    print("\n✅ RAG Processor and Workflow initialized globally.")
except Exception as e:
    print(f"\n❌ Error initializing RAG components globally: {e}")
    processor = None # Set to None if initialization fails
    enhanced_rag_workflow = None # Set to None if initialization fails


# Update the Gradio UI function to call the workflow
def process_document_ui(file):
    """Function to receive the uploaded file, run the workflow, and display results."""
    # Use the globally initialized workflow and processor
    global enhanced_rag_workflow, processor, token_tracker  # token_tracker is reassigned below

    if file is None:
        return "Please upload a file.", "No file uploaded.", pd.DataFrame() # Return empty DataFrame

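    # Depending on how the upload component is configured, Gradio passes a path string
    # or a tempfile-like object exposing the path as .name
    file_path = file if isinstance(file, str) else file.name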

    status = f"Received file: {file_path}. Starting RAG workflow..."
    print(status) # Print status to console

    # Re-initialize token tracker for this specific run to track usage per file upload
    token_tracker = TokenUsageTracker()
    print("\n📊 Token usage tracker reset for new upload.")


    # Check if workflow initialization was successful
    if enhanced_rag_workflow is None:
        error_msg = "RAG Workflow failed to initialize during startup. Cannot process file."
        print(f"❌ {error_msg}")
        return error_msg, "RAG components failed to initialize. Check server logs.", pd.DataFrame()


    # Execute enhanced workflow with the uploaded file path
    initial_state = {
        "file_paths": [file_path], # Pass the uploaded file path as a list
        "current_file_index": 0, # Not strictly used with single file processing, but keep for state structure
        "processed_filenames": [],
        "current_filename": "", # Not strictly used with single file processing
        "task_results": {},
        "workflow_status": "initialized",
        "error": None,
        "initial_state": {"file_paths": [file_path]} # Store initial state for compilation node
    }

    final_state = None # Initialize final_state to None
    try:
        print(f"Executing workflow with state: {initial_state}")
        final_state = enhanced_rag_workflow.invoke(initial_state)
        workflow_final_status = final_state.get('workflow_status', 'unknown')
        status = f"Workflow completed with status: {workflow_final_status}"
        print(status) # Print final status
        print(f"Final state: {final_state}") # Print final state for debugging


        # Process final state to display results
        task_results = final_state.get("task_results", {})
        compiled_details = ""
        results_df = pd.DataFrame() # Default to empty DataFrame


        if final_state.get("error"):
             compiled_details = f"Workflow Error: {final_state['error']}\n\n"
             # Attempt to display any partial results if available
             if task_results:
                 compiled_details += "Partial Results:\n"
                 for filename, doc_results in task_results.items():
                      compiled_details += f"\n📄 Document: {filename}\n"
                      if isinstance(doc_results, dict) and "error" in doc_results:
                           compiled_details += f"  ❌ Error: {doc_results['error']}\n"
                      elif isinstance(doc_results, dict):
                          for task_name, task_result in doc_results.items():
                              if isinstance(task_result, dict):
                                   response_preview = task_result.get('response', 'No response')
                                   compiled_details += f"  🎯 {task_name}: {response_preview[:100]}...\n"
                              else:
                                  compiled_details += f"  🎯 {task_name}: Error - {task_result}\n"
                      else:
                           compiled_details += f"  ❌ Document processing failed: {doc_results}\n"


             # If there are partial results that can be put in a DataFrame, try that
             if task_results or final_state.get("processed_filenames"): # Include ingestion errors
                  try:
                     csv_rows = []
                     # Add ingestion errors first
                     failed_ingestion_info = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
                     for fail_info in failed_ingestion_info:
                          csv_rows.append({
                            "filename": fail_info.get("filename", "unknown"),
                            "task": "ingestion_failed",
                            "response": fail_info.get("error", "Unknown ingestion error"),
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })

                     # Add task processing results/errors
                     for filename, doc_results in task_results.items():
                         if isinstance(doc_results, dict):
                            if "error" in doc_results:
                                csv_rows.append({
                                    "filename": filename,
                                    "task": "processing_failed",
                                    "response": doc_results["error"],
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })
                            else:
                                for task_name, task_result in doc_results.items():
                                    if isinstance(task_result, dict):
                                        token_usage = task_result.get("token_usage", {})
                                        csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": task_result.get("response", ""),
                                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                                            "input_tokens": token_usage.get("input_tokens", 0),
                                            "output_tokens": token_usage.get("output_tokens", 0),
                                            "total_tokens": token_usage.get("total_tokens", 0)
                                        })
                                    else:
                                         csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": f"Task error: {task_result}",
                                            "context_chunks_used": 0,
                                            "input_tokens": 0,
                                            "output_tokens": 0,
                                            "total_tokens": 0
                                        })
                         else:
                              csv_rows.append({
                                    "filename": filename,
                                    "task": "processing_failed",
                                    "response": f"Document processing failed: {doc_results}",
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })


                     if csv_rows:
                        results_df = pd.DataFrame(csv_rows)
                        column_order = [
                            "filename", "task", "response",
                            "context_chunks_used", "input_tokens",
                            "output_tokens", "total_tokens"
                        ]
                        existing_columns = [col for col in column_order if col in results_df.columns]
                        results_df = results_df[existing_columns]
                        for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                             if token_col in results_df.columns:
                                results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

                  except Exception as csv_e:
                       compiled_details += f"\nError compiling partial results DataFrame: {csv_e}"
                       results_df = pd.DataFrame({"Error": [f"Could not compile results: {csv_e}"]}) # Indicate error in DataFrame too


             return status, compiled_details, results_df # Return DataFrame even on error


        # Compile results for display if no workflow-level error
        if task_results:
            compiled_details = ""
            csv_rows = [] # Prepare data for DataFrame display

            # Add ingestion errors first
            failed_ingestion_info = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
            for fail_info in failed_ingestion_info:
                 csv_rows.append({
                    "filename": fail_info.get("filename", "unknown"),
                    "task": "ingestion_failed",
                    "response": fail_info.get("error", "Unknown ingestion error"),
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })


            for filename, results in task_results.items():
                 compiled_details += f"\n📄 DOCUMENT: {filename}\n" + "─" * 50 + "\n"
                 if isinstance(results, dict) and "error" in results:
                     compiled_details += f"❌ Document processing failed: {results['error']}\n"
                     csv_rows.append({
                        "filename": filename,
                        "task": "processing_failed",
                        "response": results["error"],
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })
                 elif isinstance(results, dict):
                    for task_name, task_result in results.items():
                        if isinstance(task_result, dict):
                            response = task_result.get("response", "No response")
                            chunks_used = task_result.get("context_chunks_used", 0)
                            token_usage = task_result.get("token_usage", {})

                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   📝 Response: {str(response)[:200]}...\n" # Limit display length; str() guards against non-string responses
                            compiled_details += f"   📚 Chunks used: {chunks_used}\n"
                            if token_usage:
                                compiled_details += f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}\n"

                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": response,
                                "context_chunks_used": chunks_used,
                                "input_tokens": token_usage.get("input_tokens", 0),
                                "output_tokens": token_usage.get("output_tokens", 0),
                                "total_tokens": token_usage.get("total_tokens", 0)
                            })
                        else:
                             compiled_details += f"\n🎯 {task_name.upper()}:\n"
                             compiled_details += f"   ❌ Task Error: {task_result}\n"
                             csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": f"Task error: {task_result}",
                                "context_chunks_used": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "total_tokens": 0
                            })
                 else:
                     compiled_details += f"❌ Document processing failed with non-dict result: {results}\n"
                     csv_rows.append({
                        "filename": filename,
                        "task": "processing_failed_unknown",
                        "response": f"Processing failed with unexpected result type: {results}",
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })


            # Create DataFrame for the summary output
            df = pd.DataFrame(csv_rows) # Create DataFrame even if empty

            if not df.empty:
                results_df = df
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in results_df.columns]
                results_df = results_df[existing_columns]
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                     if token_col in results_df.columns:
                        results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

            else:
                 results_df = pd.DataFrame({"Status": ["No task results or errors generated."]})


            # Display token usage summary in the details output as well
            compiled_details += "\n\n📊 TOTAL TOKEN USAGE SUMMARY FOR THIS RUN:\n"
            compiled_details += f"🔢 Total LLM Calls: {token_tracker.call_count}\n"
            compiled_details += f"📥 Total Input Tokens: {token_tracker.total_input_tokens:,}\n"
            compiled_details += f"📤 Total Output Tokens: {token_tracker.total_output_tokens:,}\n"
            compiled_details += f"🎯 Grand Total Tokens: {token_tracker.total_tokens:,}\n"
            if token_tracker.call_count > 0:
                compiled_details += f"📊 Average per call: {token_tracker.total_tokens/token_tracker.call_count:.1f} tokens\n"


            return status, compiled_details, results_df # Return DataFrame for display

        else:
            # task_results is empty and there is no workflow-level error, so ingestion may have failed for every file
            ingestion_errors = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
            if ingestion_errors:
                compiled_details = "Ingestion failed for the uploaded file:\n"
                for err_info in ingestion_errors:
                     compiled_details += f"- {err_info.get('filename', 'unknown')}: {err_info.get('error', 'Unknown error')}\n"
                return status, compiled_details, pd.DataFrame({"Status": ["Ingestion Failed"]})

            return status, "No task results generated and no specific errors reported.", pd.DataFrame() # Return empty DataFrame if no task results

    except Exception as e:
        error_msg = f"Workflow execution failed unexpectedly: {e}"
        print(f"❌ {error_msg}")
        # Attempt to capture state before the crash if possible
        debug_details = f"An unexpected error occurred: {e}"
        if final_state:  # final_state is pre-initialized to None, so no locals() check is needed
             debug_details += f"\nPartial workflow state available. Status: {final_state.get('workflow_status', 'unknown')}"
             if final_state.get('error'):
                  debug_details += f"\nInternal state error: {final_state['error']}"
             if final_state.get('task_results'):
                  debug_details += f"\nPartial task results available for {len(final_state['task_results'])} documents."
             if final_state.get('processed_filenames'):
                   debug_details += f"\nProcessed filenames state: {final_state['processed_filenames']}"


        # Attempt to create a simple error DataFrame
        error_df = pd.DataFrame({"Workflow Error": [error_msg], "Details": [debug_details]})

        return "Workflow Failed", debug_details, error_df


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document (PDF, JPG, PNG, etc.) to extract details.")

    file_input = gr.File(label="Upload Expense Document", file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff"])
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status", max_lines=5)
    details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
    summary_output = gr.Dataframe(label="Extracted Details (Structured)") # Use Dataframe for structured output

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

# Launch the Gradio interface
demo.launch()

**Reasoning**:
The traceback shows that `dataclass` is not defined. `dataclass` is imported from the `dataclasses` module in the notebook's first code cell, so the execution environment must have been reset and the earlier imports lost. I need to re-import the required modules, `dataclasses` in particular, and then re-run the code that defines all the classes and functions, including the global instances and the Gradio interface.
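
As a minimal sketch of how a reset like this could be detected before re-running anything, the hypothetical helper `missing_names` below (not part of the notebook) lists which expected names are absent from a namespace. In the notebook it would be called with `globals()`; here a plain dict stands in for a freshly reset runtime.

```python
from typing import Any, Dict, Iterable, List


def missing_names(required: Iterable[str], namespace: Dict[str, Any]) -> List[str]:
    """Return the names in `required` that are not defined in `namespace`.

    Hypothetical helper for illustration only; the order of `required`
    is preserved in the result.
    """
    return [name for name in required if name not in namespace]


# Simulate a runtime where only pandas survived (an assumption for the demo).
fresh_namespace = {"pd": object()}
required = ["dataclass", "pd", "TokenUsageTracker"]

print(missing_names(required, fresh_namespace))
# prints ['dataclass', 'TokenUsageTracker']
```

An empty list would suggest the earlier cells are still in effect; a non-empty list names what has to be imported or redefined first.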



In [None]:
# Re-import necessary modules
import subprocess
import time
import os
import json
import pandas as pd
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, asdict # Re-import dataclass
import warnings
import uuid
# Re-import other modules used in the definitions if necessary, based on previous cell imports
import chromadb
from unstructured.partition.auto import partition # Needed for EnhancedOCRProcessor
from langchain_ollama import ChatOllama # Needed for EnhancedRAGExpenseProcessor
import ollama # Needed for embedding and testing connection

# Re-execute code from previous cells to ensure all necessary components are defined.
# This includes TokenUsageTracker, ClaimDocument, FilenameBasedDocumentManager,
# EnhancedOCRProcessor, EnhancedIsolatedVectorStore, EnhancedExpenseTaskManager,
# EnhancedRAGExpenseProcessor, and create_enhanced_rag_workflow.

# ================================
# STEP 2: TOKEN USAGE TRACKING (Re-definition)
# ================================

class TokenUsageTracker:
    """Track token usage across all LLM calls"""

    def __init__(self):
        self.call_history = []
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.total_tokens = 0
        self.call_count = 0

    def track_call(self, operation: str, filename: str, task: str, response):
        """Track a single LLM call and extract usage info"""

        usage_info = {
            "operation": operation,
            "filename": filename,
            "task": task,
            "timestamp": datetime.now().isoformat(),
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "duration_ms": 0
        }

        # Extract token usage from response
        try:
            if hasattr(response, 'usage_metadata') and response.usage_metadata:
                usage_info["input_tokens"] = response.usage_metadata.get('input_tokens', 0)
                usage_info["output_tokens"] = response.usage_metadata.get('output_tokens', 0)
                usage_info["total_tokens"] = response.usage_metadata.get('total_tokens', 0)

            # Fallback: try response_metadata
            elif hasattr(response, 'response_metadata') and response.response_metadata:
                metadata = response.response_metadata
                usage_info["input_tokens"] = metadata.get('prompt_eval_count', 0)
                usage_info["output_tokens"] = metadata.get('eval_count', 0)
                usage_info["total_tokens"] = usage_info["input_tokens"] + usage_info["output_tokens"]
                usage_info["duration_ms"] = metadata.get('total_duration', 0) // 1000000  # Convert to ms

        except Exception as e:
            print(f"⚠️ Could not extract token usage: {e}")

        # Update totals
        self.total_input_tokens += usage_info["input_tokens"]
        self.total_output_tokens += usage_info["output_tokens"]
        self.total_tokens += usage_info["total_tokens"]
        self.call_count += 1

        # Store call history
        self.call_history.append(usage_info)

        # Print usage info
        self.print_usage_info(usage_info)

        return usage_info


    def print_usage_info(self, usage_info: Dict[str, Any]):
        """Print formatted usage information"""
        print(f"📊 TOKEN USAGE - {usage_info['operation']} | {usage_info['filename']} | {usage_info['task']}")
        print(f"   📥 Input: {usage_info['input_tokens']} tokens")
        print(f"   📤 Output: {usage_info['output_tokens']} tokens")
        print(f"   🔢 Total: {usage_info['total_tokens']} tokens")
        if usage_info['duration_ms'] > 0:
            print(f"   ⏱️ Duration: {usage_info['duration_ms']}ms")
        print()

    def print_summary(self):
        """Print overall token usage summary"""
        print("=" * 60)
        print("📊 TOTAL TOKEN USAGE SUMMARY")
        print("=" * 60)
        print(f"🔢 Total LLM Calls: {self.call_count}")
        print(f"📥 Total Input Tokens: {self.total_input_tokens:,}")
        print(f"📤 Total Output Tokens: {self.total_output_tokens:,}")
        print(f"🎯 Grand Total Tokens: {self.total_tokens:,}")

        if self.call_count > 0:
            print(f"📊 Average per call: {self.total_tokens/self.call_count:.1f} tokens")
        print()

# Global token tracker
token_tracker = TokenUsageTracker()

print("\n✅ SETUP COMPLETE - TOKEN USAGE TRACKING!")
print("=" * 70)

# ================================
# STEP 3: FILENAME-BASED DOCUMENT MANAGEMENT (Re-definition)
# ================================

@dataclass # Use the re-imported dataclass
class ClaimDocument:
    """Document with filename-based identification"""
    filename: str  # Primary identifier (no more UUIDs!)
    file_path: str
    raw_text: str
    chunks: List[str]
    metadata: Dict[str, Any]
    processed_timestamp: datetime

class FilenameBasedDocumentManager:
    """Manages documents using filenames as primary identifiers"""

    def __init__(self):
        self.documents_registry = {}  # filename -> ClaimDocument
        self.chunk_to_file_map = {}  # chunk_id -> filename

    def register_document(self, file_path: str, raw_text: str) -> str:
        """Register document using filename as ID"""

        filename = Path(file_path).stem  # Get filename without extension

        print(f"📋 REGISTERING DOCUMENT: {filename}")
        print(f"   📁 Source: {Path(file_path).name}")
        print(f"   📄 Text length: {len(raw_text)} characters")

        # Create isolated chunks for this document
        chunks = self.create_document_chunks(raw_text, filename)

        claim_doc = ClaimDocument(
            filename=filename,
            file_path=file_path,
            raw_text=raw_text,
            chunks=chunks,
            metadata={
                "file_name": Path(file_path).name,
                "file_extension": Path(file_path).suffix,
                "chunk_count": len(chunks),
                "source": "ocr_extraction"
            },
            processed_timestamp=datetime.now()
        )

        self.documents_registry[filename] = claim_doc

        # Update chunk mapping
        for i, chunk in enumerate(chunks):
            chunk_id = f"{filename}_chunk_{i}"
            self.chunk_to_file_map[chunk_id] = filename

        print(f"✅ Document registered: {filename} with {len(chunks)} chunks")
        return filename

    def create_document_chunks(self, text: str, filename: str) -> List[str]:
        """Create chunks with filename-specific context isolation"""

        print(f"🔪 CHUNKING DOCUMENT: {filename}")

        lines = text.split('\n')
        chunks = []
        current_chunk = []
        current_length = 0
        max_chunk_size = 500

        # Expense document section markers
        section_markers = [
            'total', 'amount', 'date', 'vendor', 'receipt', 'invoice',
            'item', 'quantity', 'price', 'tax', 'subtotal'
        ]

        for line in lines:
            line = line.strip()
            if not line:
                continue

            line_length = len(line)
            is_section_start = any(marker in line.lower() for marker in section_markers)

            if (current_length + line_length > max_chunk_size) or \
               (is_section_start and current_chunk and current_length > 200):

                chunk_text = '\n'.join(current_chunk)
                if chunk_text.strip():
                    # Add filename isolation metadata to chunk
                    isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                    chunks.append(isolated_chunk)

                current_chunk = [line]
                current_length = line_length
            else:
                current_chunk.append(line)
                current_length += line_length + 1

        # Add final chunk
        if current_chunk:
            chunk_text = '\n'.join(current_chunk)
            if chunk_text.strip():
                isolated_chunk = f"[DOCUMENT: {filename}]\n{chunk_text}"
                chunks.append(isolated_chunk)

        print(f"   🔪 Created {len(chunks)} chunks (avg {len(text)//len(chunks) if chunks else 0} chars each)")
        return chunks



    def get_document_context(self, filename: str) -> Optional[ClaimDocument]:
        """Get complete context for a specific document"""
        return self.documents_registry.get(filename)

    def list_all_documents(self) -> List[str]:
        """List all registered filenames"""
        return list(self.documents_registry.keys())

print("\n✅ SETUP COMPLETE - FILENAME-BASED DOCUMENT MANAGEMENT!")
print("=" * 70)

# ================================
# STEP 4: ENHANCED OCR PROCESSOR (Re-definition)
# ================================

class EnhancedOCRProcessor:
    """OCR processing with detailed progress tracking"""

    def __init__(self):
        self.supported_formats = ['.pdf', '.jpg', '.jpeg', '.png', '.tiff']

    def extract_text_from_document(self, file_path: str) -> str:
        """Extract text with detailed progress tracking"""

        filename = Path(file_path).name
        print(f"🔍 EXTRACTING TEXT FROM: {filename}")
        print(f"   📁 Full path: {file_path}")
        print(f"   📊 File size: {Path(file_path).stat().st_size / 1024:.1f} KB")

        try:
            # Ensure partition is imported from unstructured.partition.auto
            from unstructured.partition.auto import partition

            print(f"   🔄 Processing with UnstructuredIO...")

            # Process document with UnstructuredIO
            elements = partition(filename=file_path)

            print(f"   📋 Found {len(elements)} document elements")

            # Extract text from all elements
            full_text = ""
            for i, element in enumerate(elements):
                if hasattr(element, 'text') and element.text:
                    full_text += element.text + "\n"
                    if i < 5:  # Show first few elements
                        print(f"     Element {i+1}: {element.text[:50]}...")

            # Clean and normalize text
            full_text = self.clean_extracted_text(full_text)

            print(f"   ✅ Extracted {len(full_text)} characters")
            print(f"   📝 Text preview: {full_text[:100]}...")
            return full_text

        except Exception as e:
            print(f"   ❌ OCR extraction failed: {e}")
            return ""

    def clean_extracted_text(self, text: str) -> str:
        """Clean extracted text with progress info"""
        if not text:
            return ""

        original_length = len(text)
        lines = text.split('\n')
        cleaned_lines = []

        for line in lines:
            line = line.strip()
            if line and len(line) > 2:
                cleaned_lines.append(line)

        cleaned_text = '\n'.join(cleaned_lines)
        print(f"   🧹 Cleaned: {original_length} → {len(cleaned_text)} chars ({len(cleaned_lines)} lines)")

        return cleaned_text

print("\n✅ SETUP COMPLETE - ENHANCED OCR PROCESSOR!")
print("=" * 70)

# ================================
# STEP 5: ENHANCED VECTOR STORE (Re-definition)
# ================================

class EnhancedIsolatedVectorStore:
    """ChromaDB with enhanced tracking and filename-based isolation"""

    def __init__(self, embedding_model: str = "nomic-embed-text"): # Use default or passed model
        import chromadb # Ensure chromadb is imported

        self.embedding_model = embedding_model

        print(f"🗄️ INITIALIZING VECTOR STORE")
        print(f"   🤖 Embedding Model: {embedding_model}")

        # Initialize ChromaDB client using new API
        self.client = chromadb.PersistentClient(path="./chroma_db")

        # Create collection
        self.collection = self.client.get_or_create_collection(
            name="filename_based_expense_claims",
            metadata={"hnsw:space": "cosine"}
        )

        print(f"   ✅ ChromaDB initialized")

    def embed_text(self, text: str, filename: str = "unknown") -> List[float]:
        """Generate embeddings with progress tracking"""

        print(f"🔢 GENERATING EMBEDDING: {filename}")
        print(f"   📝 Text length: {len(text)} chars")

        try:
            # Ensure ollama is imported and available
            import ollama
            response = ollama.embeddings(model=self.embedding_model, prompt=text)
            embedding = response['embedding']
            print(f"   ✅ Generated {len(embedding)}-dimensional embedding")
            return embedding
        except Exception as e:
            print(f"   ❌ Embedding error: {e}")
            return []

    def add_document_chunks(self, filename: str, chunks: List[str], metadata: Dict[str, Any]):
        """Add chunks for a specific document with detailed tracking"""

        print(f"📚 ADDING CHUNKS TO VECTOR STORE: {filename}")
        print(f"   📊 Number of chunks: {len(chunks)}")

        embeddings = []
        chunk_ids = []
        metadatas = []
        embedded_chunks = []  # Only the chunks that actually received an embedding

        for i, chunk in enumerate(chunks):
            print(f"   🔄 Processing chunk {i+1}/{len(chunks)}")

            # Generate embedding
            embedding = self.embed_text(chunk, f"{filename}_chunk_{i}")
            if not embedding:
                print(f"   ⚠️ Skipping chunk {i+1} - no embedding generated")
                continue

            chunk_id = f"{filename}_chunk_{i}"
            chunk_metadata = {
                **metadata,
                "filename": filename,
                "chunk_index": i,
                "chunk_id": chunk_id,
                "isolated": True
            }

            embeddings.append(embedding)
            chunk_ids.append(chunk_id)
            metadatas.append(chunk_metadata)
            embedded_chunks.append(chunk)

        # Add to ChromaDB. Pass only the embedded chunks so that documents,
        # embeddings, metadatas and ids stay the same length when a chunk is skipped.
        if embeddings:
            self.collection.add(
                embeddings=embeddings,
                documents=embedded_chunks,
                metadatas=metadatas,
                ids=chunk_ids
            )

            print(f"   ✅ Added {len(embeddings)} chunks to vector store")
        else:
            print(f"   ❌ No chunks added - all embeddings failed")

    def query_document_specific(self, query: str, filename: str, n_results: int = 3) -> Dict[str, Any]:
        """Query specific document only - prevents cross-contamination"""

        print(f"🔍 QUERYING VECTOR STORE: {filename}")
        print(f"   ❓ Query: {query}")
        print(f"   📊 Requesting {n_results} results")

        query_embedding = self.embed_text(query, f"query_{filename}")
        if not query_embedding:
            return {"error": "Failed to generate query embedding"}

        # Query with filename filter to ensure isolation
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where={"filename": filename},  # CRITICAL: Isolates to specific document
            include=["documents", "metadatas", "distances"]
        )

        print(f"   ✅ Found {len(results['documents'][0]) if results['documents'] else 0} relevant chunks")

        return {
            "documents": results['documents'][0] if results['documents'] else [],
            "metadatas": results['metadatas'][0] if results['metadatas'] else [],
            "distances": results['distances'][0] if results['distances'] else [],
            "filename": filename
        }

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get detailed statistics about stored documents"""

        print("📊 GENERATING COLLECTION STATISTICS")

        count = self.collection.count()

        # Get unique filenames
        all_metadata = self.collection.get(include=["metadatas"])
        filenames = set()
        if all_metadata and all_metadata['metadatas']: # Added check for all_metadata existence
            for meta in all_metadata['metadatas']:
                if 'filename' in meta:
                    filenames.add(meta['filename'])

        stats = {
            "total_chunks": count,
            "unique_documents": len(filenames),
            "filenames": sorted(filenames)  # Sorted so the printed list is stable across runs
        }

        print(f"   📚 Total chunks: {stats['total_chunks']}")
        print(f"   📄 Unique documents: {stats['unique_documents']}")
        print(f"   📝 Documents: {', '.join(stats['filenames'])}")

        return stats

print("\n✅ SETUP COMPLETE - ENHANCED VECTOR STORE!")
print("=" * 70)

# ================================
# STEP 6: ENHANCED EXPENSE TASK MANAGER (Re-definition)
# ================================

class EnhancedExpenseTaskManager:
    """Manages predefined expense extraction tasks with better tracking"""

    def __init__(self):
        self.predefined_tasks = {
            "extract_amount": {
                "query": "total amount due payment cost price sum money dollar",
                "description": "Extract the total amount from this expense document",
                "expected_format": "numeric value with currency"
            },
            "extract_date": {
                "query": "date transaction purchase invoice receipt timestamp when",
                "description": "Extract the date from this expense document",
                "expected_format": "date in YYYY-MM-DD format"
            },
            "extract_vendor": {
                "query": "vendor merchant company business supplier store restaurant hotel",
                "description": "Extract vendor/merchant name from this expense document",
                "expected_format": "company or business name"
            },
            "extract_category": {
                "query": "category type classification expense kind service product item",
                "description": "Determine expense category from this document",
                "expected_format": "expense category classification"
            },
            "extract_items": {
                "query": "items products services line items purchases description details",
                "description": "Extract itemized details from this expense document",
                "expected_format": "list of items or services"
            },
            "extract_tax": {
                "query": "tax VAT GST sales tax tax rate percentage",
                "description": "Extract tax information from this expense document",
                "expected_format": "tax amount and rate"
            }
        }

    def get_task_info(self, task_name: str) -> Dict[str, str]:
        """Get complete task information"""
        return self.predefined_tasks.get(task_name, {})

    def list_available_tasks(self) -> List[str]:
        """List all available extraction tasks"""
        return list(self.predefined_tasks.keys())

print("\n✅ SETUP COMPLETE - ENHANCED EXPENSE TASK MANAGER!")
print("=" * 70)
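
# ----------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original pipeline).
# The "extract_date" task above expects "date in YYYY-MM-DD format".
# A small check like this one could be used to test whether an LLM
# response already matches that format. The helper name
# `sketch_is_iso_date` is hypothetical, and nothing in the workflow
# calls it.
# ----------------------------------------------------------------
from datetime import datetime as _sketch_datetime

def sketch_is_iso_date(text: str) -> bool:
    """Return True if `text` (ignoring surrounding whitespace) is a real YYYY-MM-DD date."""
    candidate = text.strip()
    # strptime also accepts unpadded fields such as "2024-1-5", so require
    # the full 10-character zero-padded form first.
    if len(candidate) != 10:
        return False
    try:
        _sketch_datetime.strptime(candidate, "%Y-%m-%d")
        return True
    except ValueError:
        return False

# Possible usage: sketch_is_iso_date(result["response"]) on the result of
# process_expense_task(filename, "extract_date").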


# ================================
# STEP 7: ENHANCED RAG PROCESSOR (Re-definition)
# ================================

class EnhancedRAGExpenseProcessor:
    """RAG-based expense processor with comprehensive tracking"""

    def __init__(self, text_model: str = "gemma3:1b"): # Use default or passed model
        from langchain_ollama import ChatOllama # Ensure ChatOllama is imported

        print(f"🚀 INITIALIZING RAG EXPENSE PROCESSOR")
        print(f"   🤖 Text Model: {text_model}")

        self.llm = ChatOllama(
            model=text_model,
            temperature=0.1,
            base_url="http://127.0.0.1:11434"
        )

        self.vector_store = EnhancedIsolatedVectorStore()
        self.task_manager = EnhancedExpenseTaskManager()
        self.document_manager = FilenameBasedDocumentManager()
        self.ocr_processor = EnhancedOCRProcessor()

        print("   ✅ All components initialized")

    def ingest_document(self, file_path: str) -> Optional[str]:
        """INGESTION PHASE: Process document and store in vector DB.

        Returns the filename-based document ID, or None if no text could be extracted.
        """

        filename = Path(file_path).name
        print("\n" + "="*70)
        print(f"🔄 INGESTION PHASE STARTING")
        print(f"📄 FILE: {filename}")
        print(f"📁 PATH: {file_path}")
        print("="*70)

        # Step 1: OCR extraction
        raw_text = self.ocr_processor.extract_text_from_document(file_path)
        if not raw_text:
            print("❌ INGESTION FAILED: No text extracted")
            return None

        # Step 2: Register document with filename-based system
        filename_id = self.document_manager.register_document(file_path, raw_text)

        # Step 3: Get document context
        document = self.document_manager.get_document_context(filename_id)

        # Step 4: Store in vector database
        metadata = {
            **document.metadata,
            "ingestion_timestamp": datetime.now().isoformat()
        }

        self.vector_store.add_document_chunks(
            filename=filename_id,
            chunks=document.chunks,
            metadata=metadata
        )

        print(f"✅ INGESTION COMPLETED: {filename_id}")
        print("="*70)
        return filename_id

    def process_expense_task(self, filename: str, task_name: str) -> Dict[str, Any]:
        """INFERENCE PHASE: Process specific task for document"""

        print(f"\n🎯 INFERENCE PHASE STARTING")
        print(f"📄 DOCUMENT: {filename}")
        print(f"🎯 TASK: {task_name}")
        print("-" * 50)

        # Step 1: Get task information
        task_info = self.task_manager.get_task_info(task_name)
        if not task_info:
            return {"error": f"Unknown task: {task_name}"}

        task_query = task_info.get("query", "")
        task_description = task_info.get("description", "")

        print(f"📋 Task Description: {task_description}")
        print(f"🔍 Search Query: {task_query}")

        # Step 2: Retrieve relevant chunks (ISOLATED to this document)
        retrieval_results = self.vector_store.query_document_specific(
            query=task_query,
            filename=filename,
            n_results=3
        )

        if retrieval_results.get("error"):
            return retrieval_results

        # Step 3: Prepare optimized context
        context = self.optimize_context(retrieval_results, task_name)

        # Step 4: Generate response with LLM (WITH TOKEN TRACKING)
        response_text, token_usage = self.generate_task_response_with_tracking(
            context, task_name, task_description, filename
        )

        result = {
            "task": task_name,
            "filename": filename,
            "response": response_text,
            "context_chunks_used": len(retrieval_results["documents"]),
            "token_usage": token_usage
        }

        print(f"✅ INFERENCE COMPLETED: {task_name} for {filename}")
        print("-" * 50)

        return result

    def optimize_context(self, retrieval_results: Dict[str, Any], task_name: str) -> str:
        """CONTEXT OPTIMIZATION: Reduce context overloading"""

        documents = retrieval_results.get("documents", [])
        distances = retrieval_results.get("distances", [])
        filename = retrieval_results.get("filename", "unknown")

        print(f"🔧 OPTIMIZING CONTEXT: {filename}")
        print(f"   📊 Raw chunks: {len(documents)}")

        if not documents:
            return "No relevant context found"

        # Rank chunks by ascending distance (a smaller distance means a closer match)
        doc_scores = list(zip(documents, distances))
        doc_scores.sort(key=lambda x: x[1])

        optimized_chunks = []
        total_length = 0
        max_context_length = 1500

        for i, (doc, score) in enumerate(doc_scores):
            # Remove document prefix from chunks
            clean_doc = doc.replace(f"[DOCUMENT: {filename}]\n", "")

            if total_length + len(clean_doc) <= max_context_length:
                optimized_chunks.append(clean_doc)
                total_length += len(clean_doc)
                print(f"   ✅ Chunk {i+1}: {len(clean_doc)} chars (distance: {score:.3f})")
            else:
                remaining_space = max_context_length - total_length
                if remaining_space > 100:
                    # Reserve 3 characters for the ellipsis so the chunk stays within the budget
                    truncated = clean_doc[:remaining_space - 3] + "..."
                    optimized_chunks.append(truncated)
                    print(f"   ✂️ Chunk {i+1}: truncated to {len(truncated)} chars")
                break

        context = "\n\n---\n\n".join(optimized_chunks)
        print(f"   🎯 Final context: {len(context)} chars from {len(optimized_chunks)} chunks")

        return context

    def generate_task_response_with_tracking(self, context: str, task_name: str, task_description: str, filename: str) -> Tuple[str, Dict[str, Any]]:
        """Generate LLM response with token usage tracking"""

        print(f"🤖 GENERATING LLM RESPONSE: {task_name} | {filename}")

        prompt = f"""You are an expert expense analyst. {task_description}

CONTEXT FROM EXPENSE DOCUMENT ({filename}):
{context}

TASK: {task_name}
INSTRUCTION: {task_description}

Based ONLY on the context provided above, extract the requested information. Be precise and factual. If the information is not clearly present in the context, state "Information not found in provided context."

Response:"""

        print(f"   📝 Prompt length: {len(prompt)} characters")

        try:
            response = self.llm.invoke(prompt)

            # Track token usage
            token_usage = token_tracker.track_call("llm_inference", filename, task_name, response)

            return response.content.strip(), token_usage

        except Exception as e:
            error_msg = f"Error generating response: {e}"
            print(f"   ❌ {error_msg}")
            return error_msg, {}

    def process_all_tasks_for_document(self, filename: str) -> Dict[str, Any]:
        """Process all predefined tasks for a document"""

        print(f"\n📊 PROCESSING ALL TASKS FOR: {filename}")
        print("="*50)

        tasks = self.task_manager.list_available_tasks()
        results = {}

        for i, task in enumerate(tasks, 1):
            print(f"\n[{i}/{len(tasks)}] Starting task: {task}")
            result = self.process_expense_task(filename, task)
            results[task] = result

        print(f"\n✅ ALL TASKS COMPLETED FOR: {filename}")
        return results

print("\n✅ SETUP COMPLETE - ENHANCED RAG PROCESSOR!")
print("=" * 70)
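
# ----------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original pipeline).
# optimize_context above packs the closest chunks into a fixed character
# budget and cuts the first chunk that no longer fits. The standalone
# function below shows that packing step in isolation, without printing
# or prefix stripping. The name `sketch_pack_chunks` and its parameters
# are hypothetical; the defaults mirror the 1500 / 100 values used above.
# ----------------------------------------------------------------

def sketch_pack_chunks(chunks, distances, max_chars=1500, min_tail=100):
    """Return chunks ordered by ascending distance, limited to max_chars in total."""
    ranked = sorted(zip(chunks, distances), key=lambda pair: pair[1])
    packed = []
    used = 0
    for text, _distance in ranked:
        if used + len(text) <= max_chars:
            packed.append(text)
            used += len(text)
            continue
        # The chunk does not fit: keep a cut-down piece only if enough room is left
        remaining = max_chars - used
        if remaining > min_tail:
            packed.append(text[:remaining])
        break
    return packed

# Possible usage: sketch_pack_chunks(results["documents"], results["distances"])
# on the dictionary returned by query_document_specific.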


# ================================
# STEP 8: ENHANCED WORKFLOW (Re-definition)
# ================================

from langgraph.graph import StateGraph # Ensure StateGraph is imported
from typing import TypedDict

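class EnhancedRAGWorkflowState(TypedDict):
    """Enhanced state for RAG workflow"""
    file_paths: List[str]
    current_file_index: int
    # Filename IDs (str) for successful ingestions; {"error", "filename"} dicts for failures
    processed_filenames: List[Any]
    current_filename: str
    task_results: Dict[str, Dict[str, Any]]
    workflow_status: str
    error: Optional[str]
    # Copy of the workflow inputs, read by the results compilation node
    initial_state: Dict[str, Any]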

def create_enhanced_rag_workflow(processor: EnhancedRAGExpenseProcessor) -> Any:
    """Create and compile the enhanced LangGraph workflow for RAG processing.

    Returns the compiled graph (the result of StateGraph.compile()), ready for .invoke().
    """

    def enhanced_ingestion_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced ingestion with detailed tracking"""

        print("\n" + "🔄 WORKFLOW: ENHANCED INGESTION PHASE STARTING")
        print("="*70)

        file_paths = state.get("file_paths", [])
        processed_filenames = []

        if not file_paths:
            state["workflow_status"] = "ingestion_failed"
            state["error"] = "No file paths provided for ingestion."
            print("❌ INGESTION FAILED: No file paths provided.")
            return state


        for i, file_path in enumerate(file_paths, 1):
            print(f"\n[{i}/{len(file_paths)}] Processing file: {Path(file_path).name}")

            try:
                # Use the processor to ingest the document
                filename = processor.ingest_document(file_path)
                if filename:
                    processed_filenames.append(filename)
                    print(f"✅ Successfully ingested: {filename}")
                else:
                    print(f"❌ Failed to ingest: {Path(file_path).name}")

            except Exception as e:
                print(f"❌ Error ingesting {Path(file_path).name}: {e}")
                state["error"] = str(e) # Store the error in state
                # Decide if you want to stop on first ingestion error or continue
                # For now, let's continue to process other files if possible
                processed_filenames.append({"error": str(e), "filename": Path(file_path).name})


        state["processed_filenames"] = processed_filenames
        state["workflow_status"] = "ingestion_complete" if any(isinstance(f, str) for f in processed_filenames) else "ingestion_failed" # Check if at least one file was successfully processed

        print(f"\n📊 INGESTION PHASE COMPLETED")
        successful_count = sum(1 for f in processed_filenames if isinstance(f, str))
        print(f"   ✅ Successfully processed: {successful_count} files")
        print(f"   ❌ Failed: {len(file_paths) - successful_count} files")


        return state

    def enhanced_task_processing_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced task processing with detailed tracking"""

        print("\n🎯 WORKFLOW: ENHANCED TASK PROCESSING PHASE STARTING")
        print("="*70)

        # Filter out ingestion errors before processing tasks
        processable_filenames = [f for f in state.get("processed_filenames", []) if isinstance(f, str)]
        task_results = state.get("task_results", {}) # Initialize or get existing results

        if not processable_filenames:
            state["workflow_status"] = "processing_skipped"
            print("⚠️ TASK PROCESSING SKIPPED: No documents successfully ingested.")
            return state


        for i, filename in enumerate(processable_filenames, 1):
            print(f"\n[{i}/{len(processable_filenames)}] Processing tasks for: {filename}")

            try:
                results = processor.process_all_tasks_for_document(filename)
                task_results[filename] = results
                print(f"✅ Completed all tasks for: {filename}")

            except Exception as e:
                print(f"❌ Error processing tasks for {filename}: {e}")
                task_results[filename] = {"error": str(e)}
                state["error"] = str(e) # Store the error in state


        state["task_results"] = task_results
        state["workflow_status"] = "processing_complete"

        print(f"\n📊 TASK PROCESSING PHASE COMPLETED")
        print(f"   📄 Documents processed: {len(task_results)}")

        return state

    def enhanced_results_compilation_node(state: EnhancedRAGWorkflowState) -> EnhancedRAGWorkflowState:
        """Enhanced results compilation with detailed stats and CSV export"""

        print("\n📊 WORKFLOW: ENHANCED RESULTS COMPILATION STARTING")
        print("="*70)

        task_results = state.get("task_results", {})
        # token_tracker is a global instance assumed to be available
        # processor instance (and its vector_store) is also assumed to be available

        if not task_results:
            state["workflow_status"] = "compilation_skipped"
            print("⚠️ RESULTS COMPILATION SKIPPED: No task results to compile.")
            token_tracker.print_summary()  # Print summary even if compilation skipped
            return state


        # Compile detailed statistics
        total_documents_attempted_ingestion = len(state.get("file_paths", [])) # Count original files
        successfully_ingested_filenames = [f for f in state.get("processed_filenames", []) if isinstance(f, str)]
        total_documents_successfully_processed = len(task_results) # Count documents with task results (successful or not)


        # Save results with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = f"enhanced_rag_expense_results_{timestamp}.json"
        csv_file = f"enhanced_rag_expense_results_{timestamp}.csv"

        # Create comprehensive results package
        comprehensive_results = {
            "timestamp": timestamp,
            "summary": {
                "total_documents_attempted_ingestion": total_documents_attempted_ingestion,
                "successfully_ingested_documents": len(successfully_ingested_filenames),
                "documents_with_task_results": total_documents_successfully_processed,
                "failed_ingestion": total_documents_attempted_ingestion - len(successfully_ingested_filenames)

            },
            "token_usage_summary": {
                "total_calls": token_tracker.call_count,
                "total_input_tokens": token_tracker.total_input_tokens,
                "total_output_tokens": token_tracker.total_output_tokens,
                "total_tokens": token_tracker.total_tokens
            },
            "document_results": task_results,
            "token_call_history": token_tracker.call_history,
            "initial_state": state.get("initial_state", {}) # Include initial state for debugging
        }

        # Save JSON results
        try:
            with open(results_file, 'w') as f:
                json.dump(comprehensive_results, f, indent=2, default=str)
            print(f"💾 JSON RESULTS SAVED TO: {results_file}")
        except Exception as e:
            print(f"❌ Error saving JSON results: {e}")
            state["error"] = f"Error saving JSON results: {e}"


        # Create CSV from results
        csv_rows = []

        # Add rows for documents that failed ingestion
        failed_ingestion_info = [f for f in state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
        for fail_info in failed_ingestion_info:
             csv_rows.append({
                "filename": fail_info.get("filename", "unknown"),
                "task": "ingestion_failed",
                "response": fail_info.get("error", "Unknown ingestion error"),
                "context_chunks_used": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0
            })


        for filename, doc_results in task_results.items():
            if "error" in doc_results:
                # Add error row for document that failed task processing after successful ingestion
                csv_rows.append({
                    "filename": filename,
                    "task": "processing_failed",
                    "response": doc_results["error"],
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
            else:
                # Process each task for this document
                for task_name, task_result in doc_results.items():
                    if isinstance(task_result, dict):
                        token_usage = task_result.get("token_usage", {})
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": task_result.get("response", ""),
                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                            "input_tokens": token_usage.get("input_tokens", 0),
                            "output_tokens": token_usage.get("output_tokens", 0),
                            "total_tokens": token_usage.get("total_tokens", 0)
                        })
                    else:
                         # Handle task-specific errors (if task_result is not a dict but an error string)
                        csv_rows.append({
                            "filename": filename,
                            "task": task_name,
                            "response": f"Task error: {task_result}",
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })


        # Save CSV
        df = pd.DataFrame(csv_rows) # Create DataFrame even if empty

        if not df.empty: # Check if DataFrame is not empty before processing
            try:
                # Reorder columns for better readability - handle missing columns gracefully
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in df.columns]
                df = df[existing_columns]


                # Convert token columns to numeric, handling errors
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                     if token_col in df.columns:
                        df[token_col] = pd.to_numeric(df[token_col], errors='coerce').fillna(0).astype(int)


                # Save to CSV
                df.to_csv(csv_file, index=False, encoding='utf-8')
                print(f"💾 CSV RESULTS SAVED TO: {csv_file}")

                # Display summary statistics from CSV
                print(f"\n📊 CSV Summary:")
                print(f"   📄 Total rows: {len(df)}")
                # Ensure 'filename' column exists before calling nunique
                if 'filename' in df.columns:
                    print(f"   📁 Documents: {df['filename'].nunique()}")
                     # Handle case where no tasks were processed successfully
                    successful_task_rows = df[~df['task'].isin(['ingestion_failed', 'processing_failed', 'workflow_error'])]
                    if not successful_task_rows.empty and 'filename' in successful_task_rows.columns:
                         print(f"   🎯 Avg tasks per successful doc: {successful_task_rows.groupby('filename').size().mean():.1f}")
                    else:
                        print("   🎯 Avg tasks per successful doc: N/A (No successful tasks)")

                # Ensure 'total_tokens' column exists and is numeric before summing
                if 'total_tokens' in df.columns:
                    try:
                        total_tokens_sum = df['total_tokens'].sum()
                        print(f"   🔢 Total tokens used: {total_tokens_sum:,}")
                        # Record the CSV-derived total in memory only: the JSON file was
                        # already written above, so this value is not part of the saved JSON
                        comprehensive_results['token_usage_summary']['total_tokens_from_csv'] = int(total_tokens_sum)

                    except Exception as e:
                         print(f"⚠️ Could not calculate total tokens from CSV: {e}")
                else:
                    print("⚠️ 'total_tokens' column not found in CSV.")

            except Exception as e:
                print(f"❌ Error processing or saving CSV results: {e}")
                state["error"] = f"Error processing or saving CSV results: {e}"

        else:
             print("⚠️ No CSV rows generated. Skipping CSV save.")


        # Also save a summary CSV with aggregated data per document
        summary_csv_file = f"enhanced_rag_expense_summary_{timestamp}.csv"
        summary_rows = []

        # Include documents that failed ingestion in summary
        for fail_info in failed_ingestion_info:
            summary_rows.append({
                "filename": fail_info.get("filename", "unknown"),
                "status": "Ingestion Failed",
                "error_message": fail_info.get("error", "Unknown error")
            })


        for filename, doc_results in task_results.items():
            if "error" not in doc_results:
                row = {"filename": filename, "status": "Processed"}

                # Extract key information from each task
                for task_name in ["extract_amount", "extract_date", "extract_vendor",
                                "extract_category", "extract_tax", "extract_items"]:  # One summary column per task
                    if task_name in doc_results and isinstance(doc_results[task_name], dict):
                        response = doc_results[task_name].get("response", "")
                        # Clean the response (take first line or first 100 chars)
                        cleaned = response.split('\n')[0][:100] if response else ""
                        row[task_name] = cleaned
                    elif task_name in doc_results:
                         # Handle task-specific errors
                         row[task_name] = f"Error: {doc_results[task_name]}"
                    else:
                        row[task_name] = "Task Not Run"


                # Add token totals
                total_tokens = sum(
                    doc_results.get(task, {}).get("token_usage", {}).get("total_tokens", 0)
                    for task in doc_results if isinstance(doc_results.get(task), dict)
                )
                row["total_tokens_used"] = total_tokens

                summary_rows.append(row)
            else:
                 # Add document that failed task processing after ingestion
                 summary_rows.append({
                     "filename": filename,
                     "status": "Task Processing Failed",
                     "error_message": doc_results["error"]
                 })


        if summary_rows:
            try:
                summary_df = pd.DataFrame(summary_rows)
                # Place 'filename', 'status' and 'error_message' first; reindex creates
                # any of these columns that are missing and fills them with NaN
                summary_column_order = ["filename", "status", "error_message"] + [col for col in summary_df.columns if col not in ["filename", "status", "error_message"]]
                summary_df = summary_df.reindex(columns=summary_column_order)

                summary_df.to_csv(summary_csv_file, index=False, encoding='utf-8')
                print(f"💾 SUMMARY CSV SAVED TO: {summary_csv_file}")
            except Exception as e:
                print(f"❌ Error saving summary CSV: {e}")
                state["error"] = f"Error saving summary CSV: {e}"


        print(f"\n📊 FILES SAVED:")
        if os.path.exists(results_file): print(f"   📄 Detailed JSON: {results_file}")
        if os.path.exists(csv_file) and not df.empty: print(f"   📄 Detailed CSV: {csv_file}")
        if os.path.exists(summary_csv_file) and summary_rows: print(f"   📄 Summary CSV: {summary_csv_file}")
        print(f"   📄 Documents with task results (including task failures): {total_documents_successfully_processed}/{total_documents_attempted_ingestion}")

        # Print token usage summary
        token_tracker.print_summary()

        state["workflow_status"] = "complete"
        return state


    # Build enhanced workflow
    workflow = StateGraph(EnhancedRAGWorkflowState)

    workflow.add_node("enhanced_ingestion", enhanced_ingestion_node)
    workflow.add_node("enhanced_task_processing", enhanced_task_processing_node)
    workflow.add_node("enhanced_results_compilation", enhanced_results_compilation_node)

    workflow.set_entry_point("enhanced_ingestion")
    workflow.add_edge("enhanced_ingestion", "enhanced_task_processing")
    workflow.add_edge("enhanced_task_processing", "enhanced_results_compilation")
    workflow.set_finish_point("enhanced_results_compilation")

    return workflow.compile()

print("\n✅ SETUP COMPLETE - ENHANCED WORKFLOW!")
print("=" * 70)
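
# ----------------------------------------------------------------
# Illustrative sketch (an addition, not part of the original workflow).
# The compilation node above turns the nested task results
# ({filename: {task_name: result_dict}}) into flat rows for the detailed
# CSV. The helpers below show that flattening in isolation, using the
# same column names. `sketch_result_row` and `sketch_flatten_task_results`
# are hypothetical names, and nothing in the workflow calls them.
# ----------------------------------------------------------------

def sketch_result_row(filename, task, response, chunks=0, usage=None):
    """Build one flat row with the detailed-CSV columns."""
    usage = usage or {}
    return {
        "filename": filename,
        "task": task,
        "response": response,
        "context_chunks_used": chunks,
        "input_tokens": usage.get("input_tokens", 0),
        "output_tokens": usage.get("output_tokens", 0),
        "total_tokens": usage.get("total_tokens", 0),
    }

def sketch_flatten_task_results(task_results):
    """Flatten {filename: {task: result}} into a list of row dictionaries."""
    rows = []
    for filename, doc_results in task_results.items():
        if "error" in doc_results:
            # The whole document failed during task processing
            rows.append(sketch_result_row(filename, "processing_failed", doc_results["error"]))
            continue
        for task_name, task_result in doc_results.items():
            if isinstance(task_result, dict):
                rows.append(sketch_result_row(
                    filename,
                    task_name,
                    task_result.get("response", ""),
                    task_result.get("context_chunks_used", 0),
                    task_result.get("token_usage", {}),
                ))
            else:
                # A non-dict result is treated as a task-level error message
                rows.append(sketch_result_row(filename, task_name, f"Task error: {task_result}"))
    return rows

# Possible usage: pd.DataFrame(sketch_flatten_task_results(final_state["task_results"]))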


# Initialize the RAG processor and workflow globally for efficiency in Gradio
# This avoids re-initializing the processor (and thus ChromaDB client/collection)
# on every file upload in the Gradio app.

# Ensure the TokenUsageTracker is also a persistent instance
try:
    # Re-initialize token_tracker to clear previous runs' data if desired,
    # or let it accumulate total usage across all runs.
    # For this demo, let's re-initialize for per-file/per-run tracking clarity in UI.
    token_tracker = TokenUsageTracker()
    processor = EnhancedRAGExpenseProcessor()
    # Create the workflow instance
    enhanced_rag_workflow = create_enhanced_rag_workflow(processor)
    print("\n✅ RAG Processor and Workflow initialized globally.")
except Exception as e:
    print(f"\n❌ Error initializing RAG components globally: {e}")
    processor = None # Set to None if initialization fails
    enhanced_rag_workflow = None # Set to None if initialization fails


# Update the Gradio UI function to call the workflow
def process_document_ui(file):
    """Function to receive the uploaded file, run the workflow, and display results."""
    # Use the globally initialized workflow and processor
    global enhanced_rag_workflow, processor, token_tracker

    if file is None:
        return "Please upload a file.", "No file uploaded.", pd.DataFrame() # Return empty DataFrame

    # Gradio may pass a file object exposing .name or a plain path string; accept both
    file_path = getattr(file, "name", file)

    status = f"Received file: {file_path}. Starting RAG workflow..."
    print(status) # Print status to console

    # Reset the token tracker so usage is reported per file upload.
    # A fresh instance is created whether or not a global tracker existed before.
    token_tracker = TokenUsageTracker()
    print("\n📊 Token usage tracker reset for new upload.")



    # Check if workflow initialization was successful
    if enhanced_rag_workflow is None:
         error_msg = "RAG Workflow failed to initialize during startup. Cannot process file."
         print(f"❌ {error_msg}")
         return error_msg, "RAG components failed to initialize. Check server logs.", pd.DataFrame()


    # Execute enhanced workflow with the uploaded file path
    initial_state = {
        "file_paths": [file_path], # Pass the uploaded file path as a list
        "current_file_index": 0, # Not strictly used with single file processing, but keep for state structure
        "processed_filenames": [],
        "current_filename": "", # Not strictly used with single file processing
        "task_results": {},
        "workflow_status": "initialized",
        "error": None,
        "initial_state": {"file_paths": [file_path]} # Store initial state for compilation node
    }

    final_state = None # Initialize final_state to None
    try:
        print(f"Executing workflow with state: {initial_state}")
        final_state = enhanced_rag_workflow.invoke(initial_state)
        workflow_final_status = final_state.get('workflow_status', 'unknown')
        status = f"Workflow completed with status: {workflow_final_status}"
        print(status) # Print final status
        print(f"Final state: {final_state}") # Print final state for debugging


        # Process final state to display results
        task_results = final_state.get("task_results", {})
        compiled_details = ""
        results_df = pd.DataFrame() # Default to empty DataFrame


        if final_state.get("error"):
             compiled_details = f"Workflow Error: {final_state['error']}\n\n"
             # Attempt to display any partial results if available
             if task_results or final_state.get("processed_filenames"): # Include ingestion errors
                 compiled_details += "Partial Results:\n"
                 # Add ingestion errors first
                 failed_ingestion_info = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
                 for fail_info in failed_ingestion_info:
                     compiled_details += f"\n📄 Document: {fail_info.get('filename', 'unknown')}\n"
                     compiled_details += f"  ❌ Ingestion Error: {fail_info.get('error', 'Unknown ingestion error')}\n"


                 # Add task processing results/errors
                 for filename, doc_results in task_results.items():
                      compiled_details += f"\n📄 Document: {filename}\n"
                      if isinstance(doc_results, dict) and "error" in doc_results:
                           compiled_details += f"  ❌ Task Processing Error: {doc_results['error']}\n"
                      elif isinstance(doc_results, dict):
                          for task_name, task_result in doc_results.items():
                              if isinstance(task_result, dict):
                                   response_preview = task_result.get('response', 'No response')
                                   compiled_details += f"  🎯 {task_name}: {response_preview[:100]}...\n"
                              else:
                                  compiled_details += f"  🎯 {task_name}: Error - {task_result}\n"
                      else:
                           compiled_details += f"  ❌ Document processing failed: {doc_results}\n"


             # If there are partial results that can be put in a DataFrame, try that
             if task_results or final_state.get("processed_filenames"): # Include ingestion errors
                  try:
                     csv_rows = []
                     # Add ingestion errors first
                     failed_ingestion_info = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
                     for fail_info in failed_ingestion_info:
                          csv_rows.append({
                            "filename": fail_info.get("filename", "unknown"),
                            "task": "ingestion_failed",
                            "response": fail_info.get("error", "Unknown ingestion error"),
                            "context_chunks_used": 0,
                            "input_tokens": 0,
                            "output_tokens": 0,
                            "total_tokens": 0
                        })

                     # Add task processing results/errors
                     for filename, doc_results in task_results.items():
                         if isinstance(doc_results, dict):
                            if "error" in doc_results:
                                csv_rows.append({
                                    "filename": filename,
                                    "task": "processing_failed",
                                    "response": doc_results["error"],
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })
                            else:
                                for task_name, task_result in doc_results.items():
                                    if isinstance(task_result, dict):
                                        token_usage = task_result.get("token_usage", {})
                                        csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": task_result.get("response", ""),
                                            "context_chunks_used": task_result.get("context_chunks_used", 0),
                                            "input_tokens": token_usage.get("input_tokens", 0),
                                            "output_tokens": token_usage.get("output_tokens", 0),
                                            "total_tokens": token_usage.get("total_tokens", 0)
                                        })
                                    else:
                                         csv_rows.append({
                                            "filename": filename,
                                            "task": task_name,
                                            "response": f"Task error: {task_result}",
                                            "context_chunks_used": 0,
                                            "input_tokens": 0,
                                            "output_tokens": 0,
                                            "total_tokens": 0
                                        })
                         else:
                              csv_rows.append({
                                    "filename": filename,
                                    "task": "processing_failed",
                                    "response": f"Document processing failed: {doc_results}",
                                    "context_chunks_used": 0,
                                    "input_tokens": 0,
                                    "output_tokens": 0,
                                    "total_tokens": 0
                                })


                     if csv_rows:
                        results_df = pd.DataFrame(csv_rows)
                        column_order = [
                            "filename", "task", "response",
                            "context_chunks_used", "input_tokens",
                            "output_tokens", "total_tokens"
                        ]
                        existing_columns = [col for col in column_order if col in results_df.columns]
                        results_df = results_df[existing_columns]
                        for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                             if token_col in results_df.columns:
                                results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

                  except Exception as csv_e:
                       compiled_details += f"\nError compiling partial results DataFrame: {csv_e}"
                       results_df = pd.DataFrame({"Error": [f"Could not compile results: {csv_e}"]}) # Indicate error in DataFrame too


             return status, compiled_details, results_df # Return DataFrame even on error


        # Compile results for display if no workflow-level error
        if task_results or final_state.get("processed_filenames"): # Include ingestion errors even if no task results
            compiled_details = ""
            csv_rows = [] # Prepare data for DataFrame display

            # Add ingestion errors first
            failed_ingestion_info = [f for f in final_state.get("processed_filenames", []) if isinstance(f, dict) and "error" in f]
            for fail_info in failed_ingestion_info:
                 csv_rows.append({
                    "filename": fail_info.get("filename", "unknown"),
                    "task": "ingestion_failed",
                    "response": fail_info.get("error", "Unknown ingestion error"),
                    "context_chunks_used": 0,
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "total_tokens": 0
                })
                 compiled_details += f"\n📄 DOCUMENT: {fail_info.get('filename', 'unknown')}\n"
                 compiled_details += f"  ❌ Ingestion Failed: {fail_info.get('error', 'Unknown ingestion error')}\n"


            for filename, results in task_results.items():
                 compiled_details += f"\n📄 DOCUMENT: {filename}\n" + "─" * 50 + "\n"
                 if isinstance(results, dict) and "error" in results:
                     compiled_details += f"❌ Document processing failed: {results['error']}\n"
                     csv_rows.append({
                        "filename": filename,
                        "task": "processing_failed",
                        "response": results["error"],
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })
                 elif isinstance(results, dict):
                    for task_name, task_result in results.items():
                        if isinstance(task_result, dict):
                            response = task_result.get("response", "No response")
                            chunks_used = task_result.get("context_chunks_used", 0)
                            token_usage = task_result.get("token_usage", {})

                            compiled_details += f"\n🎯 {task_name.upper()}:\n"
                            compiled_details += f"   📝 Response: {response[:200]}...\n" # Limit display length
                            compiled_details += f"   📚 Chunks used: {chunks_used}\n"
                            if token_usage:
                                compiled_details += f"   🔢 Tokens: {token_usage.get('total_tokens', 0)}\n"

                            csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": response,
                                "context_chunks_used": chunks_used,
                                "input_tokens": token_usage.get("input_tokens", 0),
                                "output_tokens": token_usage.get("output_tokens", 0),
                                "total_tokens": token_usage.get("total_tokens", 0)
                            })
                        else:
                             compiled_details += f"\n🎯 {task_name.upper()}:\n"
                             compiled_details += f"   ❌ Task Error: {task_result}\n"
                             csv_rows.append({
                                "filename": filename,
                                "task": task_name,
                                "response": f"Task error: {task_result}",
                                "context_chunks_used": 0,
                                "input_tokens": 0,
                                "output_tokens": 0,
                                "total_tokens": 0
                            })
                 else:
                     compiled_details += f"❌ Document processing failed with non-dict result: {results}\n"
                     csv_rows.append({
                        "filename": filename,
                        "task": "processing_failed_unknown",
                        "response": f"Processing failed with unexpected result type: {results}",
                        "context_chunks_used": 0,
                        "input_tokens": 0,
                        "output_tokens": 0,
                        "total_tokens": 0
                    })


            # Create DataFrame for the summary output
            df = pd.DataFrame(csv_rows) # Create DataFrame even if empty

            if not df.empty:
                results_df = df
                column_order = [
                    "filename", "task", "response",
                    "context_chunks_used", "input_tokens",
                    "output_tokens", "total_tokens"
                ]
                existing_columns = [col for col in column_order if col in results_df.columns]
                results_df = results_df[existing_columns]
                for token_col in ['input_tokens', 'output_tokens', 'total_tokens']:
                     if token_col in results_df.columns:
                        results_df[token_col] = pd.to_numeric(results_df[token_col], errors='coerce').fillna(0).astype(int)

            else:
                 results_df = pd.DataFrame({"Status": ["No task results or errors generated."]})


            # Display token usage summary in the details output as well
            compiled_details += "\n\n📊 TOTAL TOKEN USAGE SUMMARY FOR THIS RUN:\n"
            compiled_details += f"🔢 Total LLM Calls: {token_tracker.call_count}\n"
            compiled_details += f"📥 Total Input Tokens: {token_tracker.total_input_tokens:,}\n"
            compiled_details += f"📤 Total Output Tokens: {token_tracker.total_output_tokens:,}\n"
            compiled_details += f"🎯 Grand Total Tokens: {token_tracker.total_tokens:,}\n"
            if token_tracker.call_count > 0:
                compiled_details += f"📊 Average per call: {token_tracker.total_tokens/token_tracker.call_count:.1f} tokens\n"


            return status, compiled_details, results_df # Return DataFrame for display

        else:
            # If task_results is empty and no ingestion errors, something went wrong early
            return status, "Workflow completed but generated no results or errors.", pd.DataFrame()


    except Exception as e:
        error_msg = f"Workflow execution failed unexpectedly: {e}"
        print(f"❌ {error_msg}")
        # Attempt to capture state before the crash if possible
        debug_details = f"An unexpected error occurred: {e}"
        if final_state: # Check if final_state object was created before the crash
             debug_details += f"\nPartial workflow state available. Status: {final_state.get('workflow_status', 'unknown')}"
             if final_state.get('error'):
                  debug_details += f"\nInternal state error: {final_state['error']}"
             if final_state.get('task_results'):
                  debug_details += f"\nPartial task results available for {len(final_state['task_results'])} documents."
             if final_state.get('processed_filenames'):
                   debug_details += f"\nProcessed filenames state: {final_state['processed_filenames']}"


        # Attempt to create a simple error DataFrame
        error_df = pd.DataFrame({"Workflow Error": [error_msg], "Details": [debug_details]})

        return "Workflow Failed", debug_details, error_df


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Enhanced RAG-Based Expense Claims Processing")
    gr.Markdown("Upload your expense claim document (PDF, JPG, PNG, etc.) to extract details.")

    file_input = gr.File(label="Upload Expense Document", file_types=[".pdf", ".jpg", ".jpeg", ".png", ".tiff"])
    process_button = gr.Button("Process Document")

    status_output = gr.Textbox(label="Processing Status", max_lines=5)
    details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
    summary_output = gr.Dataframe(label="Extracted Details (Structured)") # Use Dataframe for structured output

    process_button.click(
        process_document_ui,
        inputs=file_input,
        outputs=[status_output, details_output, summary_output]
    )

# Launch the Gradio interface
demo.launch()

## Display results

### Subtask:
Present the extracted expense details and processing status to the user on the web page.


**Reasoning**:
Review the `process_document_ui` function and the Gradio interface definition to ensure the extracted data (status, raw text/logs, and DataFrame results) from the workflow's `final_state` is correctly passed to the respective Gradio output components (`status_output`, `details_output`, `summary_output`) and that the interface is launched. This directly addresses steps 1-5 of the subtask instructions.



In [None]:
# The code from the previous step already implements the logic to extract
# processing status, compiled raw text/logs/errors, and the structured
# results DataFrame from the workflow's final_state and return them.
# It also already maps these return values to the status_output,
# details_output, and summary_output components in the Gradio interface
# definition. The demo.launch() call is also present.

# We just need to ensure the Gradio interface definition correctly
# uses these components and the process_button click event maps
# the outputs correctly.

# Review of the Gradio interface definition:
# with gr.Blocks() as demo:
#     gr.Markdown(...)
#     gr.Markdown(...)
#
#     file_input = gr.File(...)
#     process_button = gr.Button(...)
#
#     status_output = gr.Textbox(label="Processing Status", max_lines=5)
#     details_output = gr.Textbox(label="Extracted Details (Raw)", interactive=False, lines=20)
#     summary_output = gr.Dataframe(label="Extracted Details (Structured)")
#
#     process_button.click(
#         process_document_ui,
#         inputs=file_input,
#         outputs=[status_output, details_output, summary_output]
#     )
#
# demo.launch()

# The outputs list [status_output, details_output, summary_output]
# correctly aligns with the return order of process_document_ui:
# return status, compiled_details, results_df

# status_output is a Textbox, suitable for the status string.
# details_output is a Textbox with interactive=False and multiple lines,
# suitable for displaying the compiled raw text, logs, and errors.
# summary_output is a Dataframe, suitable for displaying the pandas DataFrame results.

# The demo.launch() call will make the interface accessible.

# No changes are needed to the code as the previous step already completed
# the requirements of this subtask.

print("Review complete: The Gradio interface is configured to display results.")
print("The process_document_ui function extracts the necessary data and maps it correctly.")
print("The Gradio interface includes components for status, raw details, and a structured DataFrame.")
print("The interface is launched using demo.launch().")
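
The row-building logic in `process_document_ui` repeats the same seven-column dictionary in several branches. As a rough sketch only (the helper names `make_row` and `flatten_results` are hypothetical and not part of the notebook), the same rows could be produced by one function and then passed to `pd.DataFrame(rows)`:

```python
from typing import Any, Dict, List

# Usage columns default to zero for every error row, matching the notebook's table layout.
ZERO_USAGE = {"context_chunks_used": 0, "input_tokens": 0,
              "output_tokens": 0, "total_tokens": 0}


def make_row(filename: str, task: str, response: Any, **usage: int) -> Dict[str, Any]:
    """Build one table row; usage fields that are not supplied default to 0."""
    row = {"filename": filename, "task": task, "response": response}
    row.update({**ZERO_USAGE, **usage})
    return row


def flatten_results(processed_filenames: List[Any],
                    task_results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn ingestion errors and per-document task results into flat rows."""
    rows = []
    # Ingestion failures are stored as dicts carrying an "error" key.
    for info in processed_filenames:
        if isinstance(info, dict) and "error" in info:
            rows.append(make_row(info.get("filename", "unknown"), "ingestion_failed",
                                 info.get("error", "Unknown ingestion error")))
    for filename, doc in task_results.items():
        if not isinstance(doc, dict):
            rows.append(make_row(filename, "processing_failed",
                                 f"Document processing failed: {doc}"))
        elif "error" in doc:
            rows.append(make_row(filename, "processing_failed", doc["error"]))
        else:
            for task_name, result in doc.items():
                if isinstance(result, dict):
                    usage = result.get("token_usage") or {}
                    rows.append(make_row(
                        filename, task_name, result.get("response", ""),
                        context_chunks_used=result.get("context_chunks_used", 0),
                        input_tokens=usage.get("input_tokens", 0),
                        output_tokens=usage.get("output_tokens", 0),
                        total_tokens=usage.get("total_tokens", 0)))
                else:
                    rows.append(make_row(filename, task_name, f"Task error: {result}"))
    return rows
```

Keeping the flattening separate from the Gradio callback would also make it possible to unit-test the table contents without launching the interface.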

## Add error handling

### Subtask:
Implement error handling to provide feedback to the user if something goes wrong during upload or processing.


**Reasoning**:
Review the existing code to identify potential error points and ensure they are handled within the workflow nodes and the Gradio UI function to provide user feedback.



In [None]:
# Review the existing code, focusing on error handling.

# In enhanced_ingestion_node:
# - Catches exceptions during processor.ingest_document and stores in state["error"].
# - Adds {"error": ..., "filename": ...} to processed_filenames list for failed ingests.
# - Sets workflow_status to "ingestion_failed" if no files are successfully processed.
# - Returns the updated state.

# In enhanced_task_processing_node:
# - Filters out ingestion errors from processed_filenames.
# - Catches exceptions during processor.process_all_tasks_for_document and stores in state["error"].
# - Adds {"error": ...} for a document that failed task processing after successful ingestion.
# - Sets workflow_status to "processing_skipped" if no processable files exist.
# - Returns the updated state.

# In enhanced_results_compilation_node:
# - Catches exceptions during JSON and CSV saving attempts and stores in state["error"].
# - Includes ingestion and task processing errors in the generated CSV and summary CSV data.
# - Returns the updated state.

# In process_document_ui:
# - Checks for file is None.
# - Catches exceptions during EnhancedRAGExpenseProcessor initialization.
# - Catches exceptions during enhanced_rag_workflow.invoke().
# - Checks final_state.get("error") after workflow invocation.
# - If final_state.get("error") is present, it constructs a detailed error message including partial results.
# - It attempts to create a DataFrame from partial results or indicates an error in the DataFrame.
# - It returns appropriate status, details, and DataFrame outputs in case of errors at various stages.

# The existing code already incorporates significant error handling as requested in the subtask instructions.
# The workflow nodes capture and propagate errors via the state["error"] field and by marking individual document results with errors.
# The process_document_ui function checks for the state["error"] after the workflow run and provides user-friendly feedback in the Gradio outputs.
# Try-except blocks are present around critical initialization and workflow invocation calls.

print("Review complete: Error handling mechanisms are present in the workflow nodes and the UI function.")
print("The code attempts to capture and display errors during ingestion, task processing, compilation, initialization, and workflow execution.")
print("Partial results and error messages are included in the Gradio outputs when errors occur.")
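
One error point the review above does not cover is a bad upload caught before the workflow starts. The following is a minimal sketch, assuming the uploaded file arrives as a filesystem path; `validate_upload` is a hypothetical helper, and the extension set simply mirrors the `file_types` list passed to `gr.File`:

```python
from pathlib import Path
from typing import Optional, Tuple

# Mirrors the file_types list used by the gr.File component in the interface.
ALLOWED_EXTENSIONS = {".pdf", ".jpg", ".jpeg", ".png", ".tiff"}


def validate_upload(file_path: Optional[str]) -> Tuple[bool, str]:
    """Return (ok, message) describing whether the upload can be processed."""
    if not file_path:
        return False, "No file uploaded."
    path = Path(file_path)
    if path.suffix.lower() not in ALLOWED_EXTENSIONS:
        return False, f"Unsupported file type: {path.suffix or '(none)'}"
    if not path.is_file():
        return False, f"File not found: {path.name}"
    if path.stat().st_size == 0:
        return False, "Uploaded file is empty."
    return True, "OK"
```

A callback could run this check first and return the message in the status box with an empty DataFrame, so the user gets feedback without waiting for ingestion to fail.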

## Containerize the application (optional but recommended)

### Subtask:
Containerize the application using Docker for easier deployment.


**Reasoning**:
Create the Dockerfile to containerize the application.



In [None]:
# Create Dockerfile
dockerfile_content = """
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /app

# Install system dependencies needed for UnstructuredIO (like libmagic-dev, poppler-utils, tesseract-ocr)
# These might vary based on the exact needs of UnstructuredIO and its dependencies.
# Based on UnstructuredIO documentation and common requirements for PDF/image processing:
RUN apt-get update && apt-get install -y \\
    libpq-dev \\
    build-essential \\
    poppler-utils \\
    tesseract-ocr \\
    libtesseract-dev \\
    libleptonica-dev \\
    libmagic-dev \\
    # Additional dependencies for specific file types if needed, e.g., libxml2-dev for HTML
    # libxml2-dev \\
    # libxslt1-dev \\
    # libjpeg-dev \\
    # libpng-dev \\
    # zlib1g-dev \\
    # swig \\
    # libcurl4-openssl-dev \\
    # libssl-dev \\
    pkg-config \\
    curl \\
    gnupg \\
    # Ensure ollama is available (download script)
    # This curl command is from ollama.com/install.sh
    # It downloads the install script, which we then execute.
    # Note: Running install scripts in Dockerfile can be less transparent.
    # An alternative is to download the binary directly if possible.
    # For simplicity and following the original install method, we use the script.
    # We need to ensure the script installs ollama in a location accessible by the app.
    # The default install location is /usr/local, which is fine.
    && curl -fsSL https://ollama.com/install.sh | sh \\
    && rm -rf /var/lib/apt/lists/*

# Copy the current directory contents into the container at /app
COPY . /app

# Install any needed dependencies specified in requirements.txt
# Create a requirements.txt file based on the installed packages
# In a real scenario, you would generate this from your environment (e.g., pip freeze > requirements.txt)
# For this example, we'll list the main packages explicitly.
RUN echo "ollama" >> requirements.txt && \\
    echo "langchain-ollama" >> requirements.txt && \\
    echo "langchain-core" >> requirements.txt && \\
    echo "langchain-community" >> requirements.txt && \\
    echo "chromadb>=0.4.0" >> requirements.txt && \\
    echo "unstructured[pdf]>=0.10.0" >> requirements.txt && \\
    echo "sentence-transformers" >> requirements.txt && \\
    echo "pandas" >> requirements.txt && \\
    echo "pillow" >> requirements.txt && \\
    echo "python-dateutil" >> requirements.txt && \\
    echo "pydantic" >> requirements.txt && \\
    echo "langgraph" >> requirements.txt && \\
    echo "gradio" >> requirements.txt && \\
    echo "aiofiles" >> requirements.txt && \\
    echo "fsspec" >> requirements.txt && \\
    echo "protobuf>=3.20.2,<5.0.0dev" >> requirements.txt && \\
    echo "google-auth>=2.38.0" >> requirements.txt && \\
    echo "notebook" >> requirements.txt && \\
    echo "requests" >> requirements.txt && \\
    echo "tornado" >> requirements.txt && \\
    echo "google-api-core>=1.14.0,<2.0.0dev" >> requirements.txt && \\
    echo "google-cloud-automl" >> requirements.txt && \\
    echo "google-cloud-bigquery" >> requirements.txt && \\
    echo "google-cloud-bigquery-storage" >> requirements.txt && \\
    echo "google-cloud-translate" >> requirements.txt && \\
    echo "google-cloud-aiplatform" >> requirements.txt && \\
    echo "google-cloud-language" >> requirements.txt && \\
    echo "google-cloud-storage" >> requirements.txt && \\
    echo "google-cloud-vision" >> requirements.txt && \\
    echo "grpcio" >> requirements.txt && \\
    echo "grpcio-status" >> requirements.txt && \\
    echo "numpy" >> requirements.txt && \\
    echo "scipy" >> requirements.txt && \\
    echo "tqdm" >> requirements.txt && \\
    echo "layoutparser[ocr,layout]>=0.3.4" >> requirements.txt && \\
    echo "fastjsonschema" >> requirements.txt && \\
    echo "python-multipart" >> requirements.txt && \\
    echo "onnxruntime" >> requirements.txt && \\
    echo "onnx" >> requirements.txt && \\
    echo "fastapi" >> requirements.txt && \\
    echo "uvicorn" >> requirements.txt && \\
    echo "pydantic-settings" >> requirements.txt && \\
    echo "typing-extensions" >> requirements.txt && \\
    echo "regex" >> requirements.txt && \\
    echo "html2text" >> requirements.txt && \\
    echo "bs4" >> requirements.txt && \\
    echo "html-text" >> requirements.txt && \\
    echo "ftfy" >> requirements.txt && \\
    echo "spacy" >> requirements.txt && \\
    echo "pandas-stubs" >> requirements.txt && \\
    echo "lxml" >> requirements.txt && \\
    echo "tabulate" >> requirements.txt && \\
    echo "requests-toolbelt" >> requirements.txt && \\
    echo "more-itertools" >> requirements.txt && \\
    echo "dataclasses-json" >> requirements.txt && \\
    echo "jsonpath-ng" >> requirements.txt && \\
    echo "tenacity" >> requirements.txt && \\
    echo "typing-inspect" >> requirements.txt && \\
    echo "jsonpatch" >> requirements.txt && \\
    echo "jsonschema" >> requirements.txt && \\
    echo "markdown-it-py" >> requirements.txt && \\
    echo "mdurl" >> requirements.txt && \\
    echo "Pygments" >> requirements.txt && \\
    echo "rich" >> requirements.txt && \\
    echo "comm" >> requirements.txt && \\
    echo "jupyter_client" >> requirements.txt && \\
    echo "jupyter_core" >> requirements.txt && \\
    echo "nest_asyncio" >> requirements.txt && \\
    echo "packaging" >> requirements.txt && \\
    echo "platformdirs" >> requirements.txt && \\
    echo "psutil" >> requirements.txt && \\
    echo "pyzmq" >> requirements.txt && \\
    echo "send2trash" >> requirements.txt && \\
    echo "terminado" >> requirements.txt && \\
    echo "textdistance" >> requirements.txt && \\
    echo "traitlets" >> requirements.txt && \\
    echo "websockets" >> requirements.txt && \\
    echo "anyio" >> requirements.txt && \\
    echo "babel" >> requirements.txt && \\
    echo "debugpy" >> requirements.txt && \\
    echo "jedi" >> requirements.txt && \\
    echo "jupyter_events" >> requirements.txt && \\
    echo "jupyter_server" >> requirements.txt && \\
    echo "jupyterlab_server" >> requirements.txt && \\
    echo "matplotlib-inline" >> requirements.txt && \\
    echo "mistune" >> requirements.txt && \\
    echo "nbformat" >> requirements.txt && \\
    echo "prometheus_client" >> requirements.txt && \\
    echo "prompt_toolkit" >> requirements.txt && \\
    echo "pygments" >> requirements.txt && \\
    echo "rpds-py" >> requirements.txt && \\
    echo "Send2Trash" >> requirements.txt && \\
    echo "stack_data" >> requirements.txt && \\
    echo "tinycss2" >> requirements.txt && \\
    echo "uri-template" >> requirements.txt && \\
    echo "wcwidth" >> requirements.txt && \\
    echo "webcolors" >> requirements.txt && \\
    echo "yarl" >> requirements.txt && \\
    echo "zipp" >> requirements.txt \\
    && pip install --no-cache-dir -r requirements.txt

# Ensure spacy models are downloaded if used by unstructured or other deps
RUN python -m spacy download en_core_web_sm || echo "Spacy model download failed, proceeding..."

# Expose the port that Gradio will run on
EXPOSE 7860

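# Set the OLLAMA_HOST environment variable
ENV OLLAMA_HOST=127.0.0.1:11434

# Bind Gradio to all interfaces so the exposed port is reachable from outside the container
# (demo.launch() listens on 127.0.0.1 unless server_name or GRADIO_SERVER_NAME says otherwise)
ENV GRADIO_SERVER_NAME=0.0.0.0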

# Command to run the application
# Start Ollama server in the background and then run the Python app
# This requires ollama to be installed and in the PATH within the container.
# Use a script to handle startup order
COPY start.sh /app/start.sh
RUN chmod +x /app/start.sh

CMD ["/app/start.sh"]
"""

with open("Dockerfile", "w") as f:
    f.write(dockerfile_content)

# Create a start.sh script
# (the shebang must be the very first line of the file, so the string starts with it)
start_script_content = """#!/bin/bash

# Start Ollama server in the background
echo "Starting Ollama server..."
ollama serve &

# Wait for Ollama to be ready (optional, but good practice)
# You might need a more robust health check depending on your needs
sleep 10

# Download required Ollama models (if not already present)
echo "Downloading Ollama models..."
ollama pull nomic-embed-text || echo "nomic-embed-text pull failed, attempting to continue..."
ollama pull gemma3:1b || echo "gemma3:1b pull failed, attempting to continue..."

# Pre-download UnstructuredIO models by partitioning a small dummy image.
# This is needed because the python -m spacy download doesn't handle all unstructured deps.
# A heredoc is used because compound statements (with/try) cannot follow ';' in python -c.
echo "Pre-downloading UnstructuredIO models..."
python - <<'PY' || echo "UnstructuredIO pre-download script failed, attempting to continue..."
import os
from PIL import Image
from unstructured.partition.auto import partition
from unstructured.partition.pdf import partition_pdf

print('UnstructuredIO imports successful, models might be downloaded.')
temp_path = '/tmp/dummy_image.png'
# Write a blank 100x100 image so partition() has something to process
Image.new('RGB', (100, 100), color='white').save(temp_path, format='PNG')
try:
    partition(filename=temp_path)
    print('UnstructuredIO dummy partition successful, models should be downloaded.')
except Exception as e:
    print(f'UnstructuredIO dummy partition failed: {e}')
finally:
    if os.path.exists(temp_path):
        os.remove(temp_path)
PY


# Run the Gradio application
echo "Starting Gradio application..."
python your_app_file_name.py # Replace 'your_app_file_name.py' with the actual name of your Python script
"""

with open("start.sh", "w") as f:
    f.write(start_script_content)

print("Dockerfile and start.sh created.")
print("Remember to replace 'your_app_file_name.py' in start.sh with the actual name of your main Python script.")
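
The `sleep 10` in `start.sh` is a fixed delay, and the script's own comment notes that a more robust health check may be needed. One possible approach, sketched here under assumptions, is to poll the server until it answers. `wait_until_ready` and `ollama_is_up` are hypothetical helpers; the URL is taken from the `OLLAMA_HOST` value set in the Dockerfile, and treating an HTTP 200 on the root path as "ready" is an assumption about the server rather than something the notebook verifies:

```python
import time
import urllib.error
import urllib.request
from typing import Callable


def ollama_is_up(url: str = "http://127.0.0.1:11434") -> bool:
    """Probe the server once; any connection problem counts as 'not ready'."""
    try:
        with urllib.request.urlopen(url, timeout=2) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(probe: Callable[[], bool],
                     timeout: float = 60.0,
                     interval: float = 1.0,
                     sleep: Callable[[float], None] = time.sleep,
                     clock: Callable[[], float] = time.monotonic) -> bool:
    """Call probe() repeatedly until it returns True or the timeout elapses."""
    deadline = clock() + timeout
    while True:
        if probe():
            return True
        if clock() >= deadline:
            return False
        sleep(interval)
```

In the application entry point this might be used as `wait_until_ready(ollama_is_up)` before the first model call, with the workflow refusing to start if it returns `False`.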

## Deploy the application

### Subtask:
Deploy the web application to a cloud platform or server.


**Reasoning**:
This subtask involves deploying the containerized application. The next logical step is to build the Docker image locally using the created Dockerfile and then tag it appropriately for pushing to a container registry. This is a preparatory step before pushing and deploying.
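
Because notebook environments often ship without Docker, it may help to check for the executable before attempting a build. The sketch below is illustrative only: `build_image` is a hypothetical wrapper, and it relies on nothing beyond `shutil.which` and the standard `docker build -t IMAGE CONTEXT` invocation:

```python
import shutil
import subprocess
from typing import Tuple


def build_image(image: str, context: str = ".",
                executable: str = "docker") -> Tuple[bool, str]:
    """Build an image if the Docker CLI is on PATH; return (ok, message)."""
    docker_path = shutil.which(executable)
    if docker_path is None:
        # Nothing is executed when the CLI is missing.
        return False, f"'{executable}' not found on PATH; skipping build."
    result = subprocess.run(
        [docker_path, "build", "-t", image, context],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        return False, result.stderr.strip() or "docker build failed"
    return True, f"Built {image}"
```

Returning a message instead of raising keeps the notebook cell output readable whether or not Docker is installed.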



In [22]:
# Assume the Dockerfile and start.sh created in the previous step exist in the current directory.

# Define the image name and tag
image_name = "expense-rag-app"
image_tag = "latest"
full_image_name = f"{image_name}:{image_tag}"

print(f"🛠️ Building Docker image: {full_image_name}")

# Build the Docker image
# This command executes the Dockerfile in the current directory (.)
try:
    subprocess.run(["docker", "build", "-t", full_image_name, "."], check=True)
    print(f"✅ Docker image built successfully: {full_image_name}")
except subprocess.CalledProcessError as e:
    print(f"❌ Error building Docker image: {e}")
except FileNotFoundError:
    print("❌ Docker command not found. Please ensure Docker is installed and in your PATH.")
except Exception as e:
    print(f"❌ An unexpected error occurred during Docker build: {e}")


🛠️ Building Docker image: expense-rag-app:latest
❌ Docker command not found. Please ensure Docker is installed and in your PATH.
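
Once an environment with Docker is available, the tagging step mentioned in the reasoning above could be prepared along these lines. This is a sketch under assumptions: the registry host and namespace are placeholders, the helper names are hypothetical, and the functions only assemble `docker tag` and `docker push` argument lists without running them:

```python
from typing import List


def registry_reference(registry: str, namespace: str,
                       image_name: str, tag: str = "latest") -> str:
    """Compose a fully qualified image reference for a container registry."""
    return f"{registry.rstrip('/')}/{namespace.strip('/')}/{image_name}:{tag}"


def tag_and_push_commands(local_image: str, remote_image: str) -> List[List[str]]:
    """Return the argument lists for tagging a local image and pushing it."""
    return [
        ["docker", "tag", local_image, remote_image],
        ["docker", "push", remote_image],
    ]


# Placeholder registry and namespace; substitute real values before use.
remote = registry_reference("registry.example.com", "team", "expense-rag-app")
commands = tag_and_push_commands("expense-rag-app:latest", remote)
```

Each list in `commands` can be passed to `subprocess.run(..., check=True)` in the same way the build cell invokes `docker build`.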



## Summary:

### Data Analysis Key Findings

*   Gradio was chosen as the web framework for the front end due to its suitability for file uploads and ease of use for rapid prototyping.
*   The Gradio interface was successfully designed and implemented with components for file upload, processing, and displaying results (status, raw details, and a structured DataFrame).
*   The Gradio interface correctly receives uploaded files and provides temporary file paths, which are then used as input for the RAG workflow.
*   The existing RAG workflow was modified to accept a list of file paths from the front end.
*   Error handling was integrated into the workflow nodes and the Gradio UI function to capture and display errors during ingestion, task processing, compilation, and workflow execution.
*   Comprehensive logging and status updates were added to track the workflow progress and results.
*   Detailed and summary results, including token usage, are compiled and intended to be displayed in the Gradio interface and saved as JSON and CSV files.
*   A `Dockerfile` and a `start.sh` script were written to containerize the application, but the image was never built: the `docker build` step failed because the `docker` command was not found in the notebook environment.
*   Deployment of the application to a cloud platform could not be completed because no Docker image was available to deploy.

### Insights or Next Steps

*   Ensure a functional Docker environment is available to successfully containerize and deploy the application.
*   Verify that all necessary dependent Python classes and functions (like `EnhancedRAGExpenseProcessor`, `TokenUsageTracker`, etc.) are correctly defined and accessible in the execution environment, especially when running outside the notebook or integrated development environment where they were initially developed.
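
The second point can be checked mechanically before the interface is launched. As a small sketch, assuming the application runs as a single module whose globals should contain the workflow pieces, the hypothetical helper `missing_names` reports which expected names are absent:

```python
from typing import Iterable, List, Mapping

# Names the Gradio callback relies on, as referenced in the notebook.
REQUIRED_NAMES = (
    "EnhancedRAGExpenseProcessor",
    "TokenUsageTracker",
    "enhanced_rag_workflow",
    "process_document_ui",
)


def missing_names(required: Iterable[str], namespace: Mapping[str, object]) -> List[str]:
    """Return the required names that are not defined in the given namespace."""
    return [name for name in required if name not in namespace]
```

Calling `missing_names(REQUIRED_NAMES, globals())` at the top of the script and aborting when the list is non-empty would surface a missing definition at startup rather than on the first upload.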
