In [31]:
import json
import logging
from typing import List, Dict, Any, Optional, Tuple
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain.callbacks.manager import get_openai_callback
from langchain_core.prompts import ChatPromptTemplate

# Import your existing RAG fusion function
from rag_pipeline.rag_fusion_pipeline import *

# Configure the root logger once so INFO-level routing decisions are visible,
# then obtain this module's own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class SmartRAGTool:
    """Answer a question directly or via a RAG-fusion search, as decided by the LLM.

    The chat model is offered a single OpenAI function, ``search_knowledge_base``.
    If the model calls it, the query is answered by ``rag_fusion_answer`` over
    the FAISS index at ``local_index_path``; otherwise the model's own reply is
    returned as the answer.
    """

    # Returned when the model asked for a search that could not be completed.
    _SEARCH_FALLBACK_MESSAGE = (
        "I tried to search for information but encountered an issue. Based on general knowledge: "
        "I'd be happy to help, but I may need more specific information to give you the most accurate answer."
    )

    def __init__(self, 
                 local_index_path: str,
                 embedding_model,
                 llm_params: Optional[Dict] = None):
        """
        Initialize the Smart RAG Tool
        
        Args:
            local_index_path: Path to the FAISS vector store
            embedding_model: Embedding model for vector store
            llm_params: Keyword arguments for ``ChatOpenAI``; defaults to
                ``{"temperature": 0, "model": "gpt-4o"}`` when falsy. The same
                dict is forwarded to ``rag_fusion_answer`` as ``params``.
        """
        self.local_index_path = local_index_path
        self.embedding_model = embedding_model
        self.llm_params = llm_params or {"temperature": 0, "model": "gpt-4o"}
        self.llm = ChatOpenAI(**self.llm_params)
        
        # Function definition for OpenAI function calling
        self.function_definition = {
            "name": "search_knowledge_base",
            "description": "Search the knowledge base for specific information when the question requires domain-specific or detailed factual information that may not be in general knowledge",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find relevant information"
                    },
                    "mode": {
                        "type": "string",
                        "enum": ["original", "generated"],
                        "description": "Search mode: 'original' uses only the user query, 'generated' creates multiple related queries for better coverage"
                    },
                    "num_queries": {
                        "type": "integer",
                        "minimum": 1,
                        "maximum": 10,
                        "description": "Number of queries to generate if using 'generated' mode (default: 3)"
                    }
                },
                "required": ["query"]
            }
        }
        
        # Prompt to determine if RAG is needed.
        # NOTE(review): the prompt text has typos ("realted", and "knowledge
        # questions" probably meant "knowledge base"); left unchanged so the
        # model's routing behaviour does not shift without re-evaluation.
        self.decision_prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an AI assistant that decides whether a user question requires searching a knowledge base or can be answered with general knowledge.

Use the search_knowledge_base function ONLY when:
1. The questions is specific and realted to a policy, fact, legal specification.
2. The question requires current or specific information that might not be in general knowledge

DO NOT use the search function for:
1. General knowledge questions (e.g., "What is machine learning?", "How does photosynthesis work?")
2. Questions not related to the knowledge questions should be politely redirected to ask user for questions that are related to the knowledge base.

If you decide to search, choose the appropriate mode:
- Use "original" mode for simple, direct queries
- Use "generated" mode for complex questions that might benefit from multiple search perspectives

If you don't need to search, answer the question directly using your general knowledge."""),
            ("user", "{user_query}")
        ])

    def search_knowledge_base(self, query: str, mode: str = "generated", num_queries: int = 3) -> Dict[str, Any]:
        """
        Search the knowledge base using RAG fusion
        
        Args:
            query: Search query
            mode: Search mode ('original' or 'generated')
            num_queries: Number of queries to generate if using 'generated' mode
            
        Returns:
            Dictionary with "answer", "metadata" and "search_performed". On any
            exception, "search_performed" is False and "error" holds the message.
        """
        try:
            logger.info(f"Searching knowledge base with query: '{query}' in {mode} mode")
            
            answer, metadata = rag_fusion_answer(
                user_query=query,
                local_index_path=self.local_index_path,
                embedding_model=self.embedding_model,
                mode=mode,
                num_generated_queries=num_queries,
                top_k=5,  # Retrieve more documents for better context
                params=self.llm_params
            )
            
            return {
                "answer": answer,
                "metadata": metadata,
                "search_performed": True
            }
            
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Error in knowledge base search: {str(e)}")
            return {
                "answer": f"I encountered an error while searching the knowledge base: {str(e)}",
                "metadata": {},
                "search_performed": False,
                "error": str(e)
            }

    def process_user_query(self, user_query: str, chat_context: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """
        Process a user query and decide whether to use RAG or answer directly
        
        Args:
            user_query: The user's question
            chat_context: Optional conversation context. Not used by this
                method; accepted for interface compatibility.
            
        Returns:
            Tuple of (answer, metadata). A tuple is returned on every path,
            including errors, in which case metadata["error"] is set.
        """
        # Build metadata before anything that can raise, so the except block
        # below can always record the error (previously it hit an unbound name).
        metadata: Dict[str, Any] = {
            "user_query": user_query,
            "decision_made": None,
            "search_performed": False,
            "total_cost": 0.0,
            "token_usage": {}
        }

        try:
            # Create the chain with function calling
            chain = self.decision_prompt | self.llm.bind(
                functions=[self.function_definition],
                function_call="auto"
            )
            
            with get_openai_callback() as cb:
                response = chain.invoke({"user_query": user_query})
            
            # Track decision-making cost
            metadata["decision_cost"] = {
                "total_tokens": cb.total_tokens,
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": cb.total_cost
            }
            metadata["total_cost"] += cb.total_cost
            
            function_call = (getattr(response, "additional_kwargs", None) or {}).get("function_call")
            if not function_call:
                # LLM decided not to search - use the direct response
                logger.info("LLM decided no search needed, providing direct answer")
                metadata["decision_made"] = "direct_answer"
                return response.content, metadata

            function_name = function_call.get("name")
            # An empty arguments string is treated as "no arguments".
            function_args = json.loads(function_call.get("arguments") or "{}")

            logger.info(f"LLM decided to use function: {function_name} with args: {function_args}")
            metadata["decision_made"] = "search_needed"

            if function_name != "search_knowledge_base":
                # Only one function is offered, but never fall off the end of
                # the method: an implicit None breaks tuple unpacking in chat().
                logger.warning(f"LLM requested unknown function: {function_name}")
                metadata["error"] = f"Unknown function requested: {function_name}"
                return self._SEARCH_FALLBACK_MESSAGE, metadata

            # Execute the RAG search
            search_result = self.search_knowledge_base(
                query=function_args.get('query', user_query),
                mode=function_args.get('mode', 'generated'),
                num_queries=function_args.get('num_queries', 3)
            )

            if not search_result.get("search_performed"):
                # Fallback if search failed
                metadata["error"] = search_result.get("error")
                return self._SEARCH_FALLBACK_MESSAGE, metadata

            search_metadata = search_result.get("metadata") or {}
            # NOTE(review): update() lets keys from rag_fusion_answer's metadata
            # overwrite ours (including "total_cost" if it returns one) — confirm.
            metadata.update(search_metadata)
            metadata["search_performed"] = True
            metadata["total_cost"] += search_metadata.get("total_price", 0)
            return search_result["answer"], metadata

        except Exception as e:
            logger.exception(f"Error in process_user_query: {str(e)}")
            metadata["error"] = str(e)
            return f"I encountered an error while processing your question: {str(e)}", metadata

    def chat(self, user_query: str, chat_context: Optional[str] = None, verbose: bool = False) -> str:
        """
        Simple chat interface that handles the query and returns just the answer
        
        Args:
            user_query: The user's question
            chat_context: Optional conversation context (passed through to
                process_user_query)
            verbose: Whether to print detailed metadata
            
        Returns:
            The answer string
        """
        answer, metadata = self.process_user_query(user_query, chat_context)
        
        if verbose:
            print(f"\n--- Smart RAG Tool Execution Report ---")
            print(f"User Query: {user_query}")
            print(f"Decision Made: {metadata.get('decision_made', 'unknown')}")
            print(f"Search Performed: {metadata.get('search_performed', False)}")
            print(f"Total Cost: ${metadata.get('total_cost', 0):.4f}")
            
            if metadata.get('search_performed'):
                token_usage = metadata.get('token_usage', {})
                print(f"Total Tokens Used: {token_usage.get('total_tokens', 0)}")
                print(f"Queries Used: {metadata.get('queries_used', [])}")
                print(f"Documents Retrieved: {metadata.get('num_documents_retrieved', 0)}")
            
            print(f"--- End Report ---\n")
        
        return answer


# Usage Example
def create_smart_rag_tool(local_index_path: str,
                          embedding_model,
                          llm_params: Optional[Dict] = None) -> "SmartRAGTool":
    """Factory function to create a Smart RAG Tool instance.

    Args:
        local_index_path: Path to the FAISS vector store.
        embedding_model: Embedding model used with the vector store.
        llm_params: Optional ``ChatOpenAI`` keyword arguments. When omitted,
            ``{"temperature": 0, "model": "gpt-4o"}`` is used (the previous
            hard-coded value).

    Returns:
        A configured SmartRAGTool.
    """
    if llm_params is None:
        llm_params = {"temperature": 0, "model": "gpt-4o"}
    return SmartRAGTool(
        local_index_path=local_index_path,
        embedding_model=embedding_model,
        llm_params=llm_params
    )


# Example usage:
    # Initialize the tool
    # smart_rag = create_smart_rag_tool("./faiss_index", your_embedding_model)
    
    # Example questions that would trigger different behaviors:
    
    # This would likely NOT trigger RAG (general knowledge)
    # answer = smart_rag.chat("What is machine learning?", verbose=True)
    # print(f"Answer: {answer}")
    
    # This would likely trigger RAG (specific/domain knowledge)
    # answer = smart_rag.chat("What are the specific implementation details of our authentication system?", verbose=True)
    # print(f"Answer: {answer}")
    
    # This would likely trigger RAG with generated mode (complex query)
    # answer = smart_rag.chat("How does our system handle user permissions and what are the security implications?", verbose=True)
    # print(f"Answer: {answer}")
    

In [28]:
from langchain_openai import OpenAIEmbeddings
# Build the notebook-wide tool against the local FAISS index with
# OpenAI's large embedding model.
smart_rag = create_smart_rag_tool(
    local_index_path="./data/faiss_index",
    embedding_model=OpenAIEmbeddings(model="text-embedding-3-large"),
)

In [None]:
# Domain-specific question: expected to be routed to the knowledge-base search.
response = smart_rag.chat(
    user_query="Какой пороговый уровень ОЗП",
    verbose=True,
)
print(response)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:LLM decided to use function: search_knowledge_base with args: {'query': 'пороговый уровень ОЗП', 'mode': 'original'}
INFO:__main__:Searching knowledge base with query: 'пороговый уровень ОЗП' in original mode
INFO:faiss.loader:Loading faiss with AVX2 support.
INFO:faiss.loader:Successfully loaded faiss with AVX2 support.
INFO:faiss:Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss.
INFO:rag_pipeline.rag_fusion_pipeline:Running in original query mode.
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



--- Smart RAG Tool Execution Report ---
User Query: Какой пороговый уровень ОЗП
Decision Made: search_needed
Search Performed: True
Total Cost: $0.0065
Total Tokens Used: 1720
Queries Used: ['пороговый уровень ОЗП']
Documents Retrieved: 4
--- End Report ---



'Пороговый уровень для оценки знаний педагогов (ОЗП) зависит от квалификационной категории педагога. Согласно предоставленным документам, пороговые уровни следующие:\n\n- Для квалификационной категории «педагог-стажер/педагог» - 50%;\n- Для квалификационной категории «педагог-модератор» - 60%;\n- Для квалификационной категории «педагог-эксперт» - 70%;\n- Для квалификационной категории «педагог-исследователь» - 80%;\n- Для квалификационной категории «педагог-мастер» - 90%.\n\nДля первых руководителей, заместителей руководителя организаций образования и методических кабинетов (центров) пороговый уровень составляет 70%.'

In [30]:
# General-knowledge question: expected to be answered directly, without search.
smart_rag.chat(user_query="Какая столица Франции?", verbose=True)

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:__main__:LLM decided no search needed, providing direct answer



--- Smart RAG Tool Execution Report ---
User Query: Какая столица Франции?
Decision Made: direct_answer
Search Performed: False
Total Cost: $0.0009
--- End Report ---



'Столица Франции — Париж.'

In [None]:
import json
import os
import logging
from typing import Dict, Any, Optional, Tuple, List
from pathlib import Path
import mimetypes
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.callbacks.manager import get_openai_callback

# Same logging configuration as the first cell: INFO on the root logger,
# plus a module-named logger for this cell's classes.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class PDFRetrievalTool:
    """Detect requests for a specific document and resolve them to a local file.

    Document names and file paths are read from a JSON object at
    ``documents_json_path`` (``{"Display name": "/path/to/file.pdf", ...}``).
    An LLM offered a single ``retrieve_document`` function decides whether the
    user is asking for a document; the name is then matched locally.
    """

    def __init__(self, 
                 documents_json_path: str,
                 llm_params: Optional[Dict] = None):
        """
        Initialize the PDF Retrieval Tool
        
        Args:
            documents_json_path: Path to JSON file containing document mappings
            llm_params: Keyword arguments for ``ChatOpenAI``; defaults to
                ``{"temperature": 0, "model": "gpt-4o"}`` when falsy.
        """
        self.documents_json_path = documents_json_path
        self.llm_params = llm_params or {"temperature": 0, "model": "gpt-4o"}
        self.llm = ChatOpenAI(**self.llm_params)
        
        # Load document mappings
        self.document_mappings = self._load_document_mappings()
        
        # Function definition for OpenAI function calling
        self.function_definition = {
            "name": "retrieve_document",
            "description": "Retrieve a specific PDF document by name when the user explicitly asks for a document, manual, guide, or file",
            "parameters": {
                "type": "object",
                "properties": {
                    "document_name": {
                        "type": "string",
                        "description": "The name or identifier of the document to retrieve"
                    },
                    "search_terms": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Alternative search terms or keywords to find the document if exact name doesn't match"
                    }
                },
                "required": ["document_name"]
            }
        }
        
        # Prompt to determine if document retrieval is needed. It embeds the
        # current document list, so reload_document_mappings() rebuilds it.
        self.decision_prompt = self._build_decision_prompt()

    def _build_decision_prompt(self) -> "ChatPromptTemplate":
        """Build the routing prompt that lists the currently available documents."""
        # Document names are data, not template syntax: escape braces so that
        # ChatPromptTemplate (f-string format) does not treat a name such as
        # "Report {2024}" as an input variable and fail at invoke time.
        available = self._format_available_documents().replace("{", "{{").replace("}", "}}")
        system_message = f"""You are an AI assistant that determines if a user is asking for a specific document/PDF file.

Available documents:
{available}

Use the retrieve_document function ONLY when:
1. The user explicitly asks for a document, manual, guide, report, or PDF
2. The user mentions wanting to "download", "get", "retrieve", or "access" a specific file
3. The user asks for documentation about a specific topic that matches our available documents

Examples of when to use retrieve_document:
- "Can I get the user manual?"
- "I need the API documentation"
- "Download the installation guide"
- "Show me the quarterly report"
- "I want to see the policy document"

Examples of when NOT to use retrieve_document:
- "Tell me about our API" (information request, not document request)
- "How do I install the software?" (question, not document request)
- "What's in the quarterly report?" (asking for summary, not the document itself)

When calling the function, try to match the user's request to the most appropriate document name from the available list. If unsure, include relevant search terms."""
        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            ("user", "{user_query}")
        ])

    def _read_document_mappings(self) -> Dict[str, str]:
        """Read and validate the JSON mapping file, raising on any failure.

        Raises:
            FileNotFoundError: the JSON file does not exist.
            json.JSONDecodeError: the file is not valid JSON.
            ValueError: the top-level JSON value is not an object.
        """
        with open(self.documents_json_path, 'r', encoding='utf-8') as f:
            raw = json.load(f)
        if not isinstance(raw, dict):
            raise ValueError(
                f"Document mappings must be a JSON object of name -> path, got {type(raw).__name__}"
            )
        # Only string paths are usable by os.path / open; drop anything else.
        mappings = {name: path for name, path in raw.items() if isinstance(path, str)}
        skipped = len(raw) - len(mappings)
        if skipped:
            logger.warning(f"Ignored {skipped} document mapping(s) whose path is not a string")
        return mappings

    def _load_document_mappings(self) -> Dict[str, str]:
        """Load document name to path mappings from JSON file.

        Best-effort: a missing, unparsable or malformed file is logged and an
        empty mapping is returned so the tool can still be constructed.
        """
        try:
            mappings = self._read_document_mappings()
        except FileNotFoundError:
            logger.error(f"Document mappings file not found: {self.documents_json_path}")
            return {}
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing document mappings JSON: {e}")
            return {}
        except ValueError as e:
            logger.error(f"Invalid document mappings: {e}")
            return {}
        logger.info(f"Loaded {len(mappings)} document mappings")
        return mappings

    def _format_available_documents(self) -> str:
        """Format available documents for the prompt as a bulleted list."""
        if not self.document_mappings:
            return "No documents currently available."
        return "\n".join(f"- {doc_name}" for doc_name in self.document_mappings)

    def _find_document(self, document_name: str, search_terms: Optional[List[str]] = None) -> Optional[Tuple[str, str]]:
        """
        Find a document by name or search terms
        
        Matching order: exact name (case-insensitive), then substring in either
        direction, then each search term as a substring of a name. Only entries
        whose file exists on disk are returned. Empty or blank needles are
        ignored, because an empty string is a substring of every name and would
        otherwise match an arbitrary document.
        
        Args:
            document_name: The requested document name
            search_terms: Alternative search terms
            
        Returns:
            Tuple of (matched_name, file_path) or None if not found
        """
        needle = (document_name or "").strip().lower()

        if needle:
            # Direct name match (case-insensitive)
            for doc_name, doc_path in self.document_mappings.items():
                if doc_name.lower() == needle:
                    if os.path.exists(doc_path):
                        return doc_name, doc_path
                    logger.warning(f"Document found in mapping but file doesn't exist: {doc_path}")

            # Partial name match
            for doc_name, doc_path in self.document_mappings.items():
                name = doc_name.lower()
                if name and (needle in name or name in needle) and os.path.exists(doc_path):
                    return doc_name, doc_path

        # Search terms match
        for term in search_terms or []:
            if not isinstance(term, str) or not term.strip():
                continue
            term_lower = term.strip().lower()
            for doc_name, doc_path in self.document_mappings.items():
                if term_lower in doc_name.lower() and os.path.exists(doc_path):
                    return doc_name, doc_path

        return None

    def retrieve_document(self, document_name: str, search_terms: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Retrieve a document by name
        
        Args:
            document_name: Name of the document to retrieve
            search_terms: Alternative search terms
            
        Returns:
            Dictionary with "success" and "message". On success it also has
            the matched name, path, size and MIME type; when nothing matches it
            lists the available documents; on an exception it has "error".
        """
        try:
            result = self._find_document(document_name, search_terms)
            
            if result is None:
                available_docs = list(self.document_mappings)
                more = "..." if len(available_docs) > 5 else ""
                return {
                    "success": False,
                    "message": f"Document '{document_name}' not found.",
                    "available_documents": available_docs,
                    "suggestion": f"Available documents: {', '.join(available_docs[:5])}{more}"
                }
            
            matched_name, file_path = result
            
            # Get file info
            file_size = os.stat(file_path).st_size
            file_size_mb = file_size / (1024 * 1024)
            
            # Get MIME type (falls back to PDF when it cannot be guessed)
            mime_type, _ = mimetypes.guess_type(file_path)
            
            logger.info(f"Successfully retrieved document: {matched_name} ({file_size_mb:.2f} MB)")
            
            return {
                "success": True,
                "document_name": matched_name,
                "file_path": file_path,
                "file_size": file_size,
                "file_size_mb": round(file_size_mb, 2),
                "mime_type": mime_type or "application/pdf",
                "message": f"Document '{matched_name}' is ready for download.",
                "download_info": {
                    "filename": os.path.basename(file_path),
                    "extension": Path(file_path).suffix
                }
            }
            
        except Exception as e:
            logger.exception(f"Error retrieving document '{document_name}': {str(e)}")
            return {
                "success": False,
                "message": f"Error retrieving document: {str(e)}",
                "error": str(e)
            }

    def _no_document_message(self) -> str:
        """Reply used when the query is not (or cannot be handled as) a document request."""
        return (
            "I don't see a request for a specific document in your message. "
            f"If you're looking for a document, I have access to {len(self.document_mappings)} documents. "
            "You can ask me to retrieve documents like 'Can I get the user manual?' or 'I need the API documentation'."
        )

    def process_user_query(self, user_query: str) -> Tuple[str, Dict[str, Any]]:
        """
        Process a user query and determine if they're asking for a document
        
        Args:
            user_query: The user's request
            
        Returns:
            Tuple of (response_message, metadata). A tuple is returned on every
            path; metadata["document_requested"] is True only when the model
            called retrieve_document.
        """
        # Build metadata before anything that can raise so the except block can
        # always record the error (previously it hit an unbound name).
        metadata: Dict[str, Any] = {
            "user_query": user_query,
            "document_requested": False,
            "document_retrieved": False,
            "total_cost": 0.0,
            "available_documents_count": len(self.document_mappings)
        }

        try:
            # Create the chain with function calling
            chain = self.decision_prompt | self.llm.bind(
                functions=[self.function_definition],
                function_call="auto"
            )
            
            with get_openai_callback() as cb:
                response = chain.invoke({"user_query": user_query})
            
            # Track decision-making cost
            metadata["decision_cost"] = {
                "total_tokens": cb.total_tokens,
                "prompt_tokens": cb.prompt_tokens,
                "completion_tokens": cb.completion_tokens,
                "total_cost": cb.total_cost
            }
            metadata["total_cost"] += cb.total_cost
            
            function_call = (getattr(response, "additional_kwargs", None) or {}).get("function_call")
            if not function_call:
                # LLM decided user is not asking for a document
                return self._no_document_message(), metadata

            function_name = function_call.get("name")
            if function_name != "retrieve_document":
                # Never fall off the end: an implicit None breaks callers that
                # unpack the (message, metadata) tuple.
                logger.warning(f"LLM requested unknown function: {function_name}")
                metadata["error"] = f"Unknown function requested: {function_name}"
                return self._no_document_message(), metadata

            # An empty arguments string is treated as "no arguments".
            function_args = json.loads(function_call.get("arguments") or "{}")
            logger.info(f"LLM decided to retrieve document: {function_args}")
            metadata["document_requested"] = True
            metadata["function_args"] = function_args

            # Execute document retrieval
            retrieval_result = self.retrieve_document(
                document_name=function_args.get('document_name', ''),
                search_terms=function_args.get('search_terms', [])
            )
            
            metadata.update(retrieval_result)
            metadata["document_retrieved"] = retrieval_result.get("success", False)
            
            if retrieval_result.get("success"):
                doc_info = retrieval_result
                response_message = (
                    f"✅ **{doc_info['document_name']}** is ready for download!\n\n"
                    f"📄 **File:** {doc_info['download_info']['filename']}\n"
                    f"📊 **Size:** {doc_info['file_size_mb']} MB\n"
                    f"📁 **Location:** {doc_info['file_path']}\n\n"
                    f"You can download the document from the specified location."
                )
            else:
                response_message = (
                    f"❌ {retrieval_result['message']}\n\n"
                    f"{retrieval_result.get('suggestion', '')}"
                )
            return response_message, metadata

        except Exception as e:
            logger.exception(f"Error in process_user_query: {str(e)}")
            metadata["error"] = str(e)
            return f"I encountered an error while processing your request: {str(e)}", metadata

    def list_available_documents(self) -> str:
        """Return a formatted list of available documents, marking missing files."""
        if not self.document_mappings:
            return "No documents are currently available."
        
        doc_list = []
        for i, (doc_name, doc_path) in enumerate(self.document_mappings.items(), 1):
            file_exists = "✅" if os.path.exists(doc_path) else "❌"
            doc_list.append(f"{i}. {file_exists} **{doc_name}**")
        
        return f"**Available Documents ({len(self.document_mappings)}):**\n\n" + "\n".join(doc_list)

    def reload_document_mappings(self) -> bool:
        """Reload document mappings from the JSON file.

        Returns:
            True if the file was read and the mappings (and the routing prompt
            that lists them) were replaced; False if reading failed, in which
            case the previously loaded mappings are kept.
        """
        try:
            mappings = self._read_document_mappings()
        except Exception as e:
            logger.error(f"Failed to reload document mappings: {e}")
            return False
        self.document_mappings = mappings
        # The routing prompt embeds the document list, so it must be rebuilt.
        self.decision_prompt = self._build_decision_prompt()
        logger.info("Document mappings reloaded successfully")
        return True


# Enhanced Smart RAG Tool that includes PDF retrieval
class EnhancedSmartRAGTool:
    """Combine document retrieval (PDFRetrievalTool) with question answering (SmartRAGTool).

    Each query is first checked by the PDF tool; only queries for which the
    model did not request a document are forwarded to the Smart RAG tool.
    """

    def __init__(self, 
                 local_index_path: str,
                 embedding_model,
                 documents_json_path: str,
                 llm_params: Optional[Dict] = None):
        """
        Initialize the Enhanced Smart RAG Tool with PDF retrieval capability

        Args:
            local_index_path: Path to the FAISS vector store.
            embedding_model: Embedding model used with the vector store.
            documents_json_path: Path to the JSON file mapping document names to paths.
            llm_params: Optional ``ChatOpenAI`` keyword arguments, shared by both tools.

        Raises:
            ImportError: SmartRAGTool is available neither from the sibling
                ``smart_rag_tool`` module nor in this module's namespace.
        """
        smart_rag_cls = self._resolve_smart_rag_class()

        self.smart_rag = smart_rag_cls(local_index_path, embedding_model, llm_params)
        self.pdf_tool = PDFRetrievalTool(documents_json_path, llm_params)
        self.llm_params = llm_params or {"temperature": 0, "model": "gpt-4o"}

    @staticmethod
    def _resolve_smart_rag_class():
        """Return SmartRAGTool from the package module, else from this module's globals."""
        try:
            from .smart_rag_tool import SmartRAGTool  # Import your existing Smart RAG Tool
        except ImportError:
            # A relative import only works inside a package. In a notebook or a
            # plain script it raises, so fall back to a SmartRAGTool class that
            # is already defined in this module (e.g. by an earlier cell).
            smart_rag_cls = globals().get("SmartRAGTool")
            if smart_rag_cls is None:
                raise
            return smart_rag_cls
        return SmartRAGTool

    def process_query(self, user_query: str, chat_context: Optional[str] = None, verbose: bool = False) -> str:
        """
        Process a user query with both document retrieval and RAG capabilities
        
        Args:
            user_query: The user's question or request
            chat_context: Optional conversation context
            verbose: Whether to print detailed information
            
        Returns:
            The response string
        """
        # First, check if user is asking for a document
        doc_response, doc_metadata = self.pdf_tool.process_user_query(user_query)
        
        if doc_metadata.get("document_requested", False):
            if verbose:
                print(f"\n--- Document Retrieval Attempt ---")
                print(f"Document Requested: {doc_metadata.get('document_requested', False)}")
                print(f"Document Retrieved: {doc_metadata.get('document_retrieved', False)}")
                print(f"Cost: ${doc_metadata.get('total_cost', 0):.4f}")
                print(f"--- End Document Retrieval ---\n")
            
            return doc_response
        
        # If not asking for a document, use Smart RAG
        return self.smart_rag.chat(user_query, chat_context, verbose)

    def list_documents(self) -> str:
        """List all available documents"""
        return self.pdf_tool.list_available_documents()

    def reload_documents(self) -> bool:
        """Reload document mappings; returns the PDF tool's success flag."""
        return self.pdf_tool.reload_document_mappings()


# Usage Example
def create_enhanced_rag_tool(local_index_path: str, 
                           embedding_model, 
                           documents_json_path: str,
                           llm_params: Optional[Dict] = None) -> "EnhancedSmartRAGTool":
    """Factory function to create an Enhanced Smart RAG Tool instance.

    Args:
        local_index_path: Path to the FAISS vector store.
        embedding_model: Embedding model used with the vector store.
        documents_json_path: Path to the JSON file mapping document names to paths.
        llm_params: Optional ``ChatOpenAI`` keyword arguments. When omitted,
            ``{"temperature": 0, "model": "gpt-4o"}`` is used (the previous
            hard-coded value).

    Returns:
        A configured EnhancedSmartRAGTool.
    """
    if llm_params is None:
        llm_params = {"temperature": 0, "model": "gpt-4o"}
    return EnhancedSmartRAGTool(
        local_index_path=local_index_path,
        embedding_model=embedding_model,
        documents_json_path=documents_json_path,
        llm_params=llm_params
    )


# Example usage:
if __name__ == "__main__":
    # Intentionally a no-op when run as a script: every step below is
    # commented out. Fill in a real index path, embedding model and
    # documents.json path before enabling them.
    #
    # Example JSON structure for documents.json (display name -> file path):
    # {
    #     "User Manual": "/path/to/user_manual.pdf",
    #     "API Documentation": "/path/to/api_docs.pdf",
    #     "Installation Guide": "/path/to/install_guide.pdf",
    #     "Quarterly Report Q1 2024": "/path/to/q1_report.pdf"
    # }
    
    # Initialize the enhanced tool
    # enhanced_rag = create_enhanced_rag_tool(
    #     local_index_path="./faiss_index",
    #     embedding_model=your_embedding_model,
    #     documents_json_path="./documents.json"
    # )
    
    # Example interactions:
    # enhanced_rag.process_query("Can I get the user manual?", verbose=True)  # Document retrieval
    # enhanced_rag.process_query("How does authentication work?", verbose=True)  # RAG search
    # enhanced_rag.process_query("What is machine learning?", verbose=True)  # Direct answer
    # print(enhanced_rag.list_documents())  # List available documents
    
    pass