# In [4]:  (Jupyter cell prompt left over from a notebook export)
"""
Inference service module for CMS and DSML backends.
"""

import json
from typing import Optional, Dict, List, Any, Literal, Callable, Union
from loguru import logger

from ..utils.request_handler import RequestHandler
from ..utils.gemini_utils import GeminiUtils


class InferenceService:
    """
    Inference service for language model interactions.
    
    This class provides methods to interact with different inference backends,
    including CMS and DSML, with support for function/tool calling.
    """

    def __init__(
        self,
        app_id: str,
        request_handler: RequestHandler,
        preferences: Dict[str, Any],
    ):
        """
        Initialize the inference service.
        
        Args:
            app_id: Application identifier
            request_handler: Request handler for API calls
            preferences: User preferences for inference
        """
        self.app_id = app_id
        self.request_handler = request_handler
        self.preferences = preferences or {}
        
        # Document type mapping for wider support
        self.doc_type_mapping = {
            "pdf": "application/pdf", 
            "jpg": "image/jpg",
            "jpeg": "image/jpeg",
            "png": "image/png"
        }

    def _create_cms_headers(self, sso: str) -> Dict[str, str]:
        """
        Create standard headers for CMS API requests.
        
        Args:
            sso: User's SSO cookie
            
        Returns:
            Dictionary of headers
        """
        return {"Cookie": "SSO=" + sso, "app-id": self.app_id}

    def _set_cms_preferences(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Set user preferences in the payload.
        
        Args:
            payload: Request payload to modify
            
        Returns:
            Modified payload with preferences
        """
        if "conversation" in payload:
            payload["conversation"]["preferences"] = self.preferences
        return payload

    def _get_cms_payload(
        self,
        question: str,
        schema: Optional[Any] = None,
        tools: Optional[List[Union[Dict[str, Any], Callable]]] = None,
        document_id: Optional[str] = None,
        document_type: Optional[Literal["pdf", "jpg", "jpeg", "png"]] = None,
        payload: Optional[Dict[str, Any]] = None,
        function_results: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """
        Create the payload for CMS API requests.
        
        Args:
            question: User question or prompt
            schema: Optional response schema (Pydantic model or dict)
            tools: Optional LLM tools (functions or definitions)
            document_id: Optional document reference ID
            document_type: Optional document type
            payload: Optional base payload to modify
            function_results: Optional results from function execution
            
        Returns:
            Complete payload for the API request
        """
        # Start with provided payload or create new one
        if payload:
            # For stream endpoints that need different structure
            payload["question"] = question
        else:
            # Standard structure for other endpoints
            payload = {"conversation": {"title": f"{self.app_id}"}}
            payload["question"] = {"question": f"{question}"}
            payload = self._set_cms_preferences(payload)

        # Convert schema if provided
        if schema:
            if "responseConfig" not in payload["question"]:
                payload["question"]["responseConfig"] = {}
                
            # Convert schema if it's a Pydantic model
            gemini_schema = GeminiUtils.pydantic_to_schema(schema)
            
            payload["question"]["responseConfig"].update({
                "responseType": "application/json",
                "responseSchema": gemini_schema,
            })

        # Convert and add tools if provided
        if tools:
            # Convert functions to tool definitions if needed
            tool_definitions = []
            for tool in tools:
                if callable(tool):
                    # Convert function to tool definition
                    tool_definition = GeminiUtils.python_to_function(tool)
                    tool_definitions.append(tool_definition)
                elif isinstance(tool, dict):
                    # Already a tool definition
                    tool_definitions.append(tool)
                else:
                    logger.warning(f"Ignoring unsupported tool type: {type(tool)}")
                    
            if tool_definitions:
                payload["question"]["llmTools"] = tool_definitions

        # Add document context if provided
        if document_id and document_type:
            if document_type not in self.doc_type_mapping:
                logger.warning(f"Unsupported document type: {document_type}")
            else:
                payload["questionContext"] = {
                    "type": "LEXDOCUMENT",
                    "documentReference": document_id,
                    "documentType": self.doc_type_mapping[document_type],
                }
                
        # Add function results if provided
        if function_results:
            # Add results to context
            if "Content:" not in question:
                # Format function results for inclusion in the question context
                results_str = json.dumps(function_results)
                payload["question"]["question"] = f"{question}\nContent: {results_str}"

        return payload
    
    def _process_response(self, response: Any) -> Optional[Any]:
        """
        Process API response with error handling.
        
        Args:
            response: API response
            
        Returns:
            Processed response or None
        """
        if response is None:
            logger.error("Failed to get response after retries")
            return None
            
        try:
            # If response has json method, use it
            if hasattr(response, 'json') and callable(response.json):
                return response.json()
            # If response is already processed, return it
            return response
        except Exception as e:
            logger.error(f"Error parsing response: {e}")
            # Try to return as text if json fails
            try:
                if hasattr(response, 'text'):
                    if callable(response.text):
                        return response.text()
                    return response.text
            except Exception:
                pass
            return None

    async def _process_async_response(self, response: Any) -> Optional[Any]:
        """
        Process async API response with error handling.
        
        Args:
            response: Async API response
            
        Returns:
            Processed response or None
        """
        if response is None:
            logger.error("Failed to get async response after retries")
            return None
            
        try:
            # Try to parse as JSON first
            return await response.json()
        except Exception as e:
            logger.error(f"Error parsing async response as JSON: {e}")
            # Try to return as text if json fails
            try:
                return await response.text()
            except Exception:
                pass
            return None

    def cms_inference(
        self,
        question: str,
        sso: str,
        schema: Optional[Any] = None,
        tools: Optional[List[Union[Dict[str, Any], Callable]]] = None,
        document_id: Optional[str] = None,
        document_type: Optional[Literal["pdf", "jpg", "jpeg", "png"]] = None,
        function_results: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Send a question to the CMS "conversations_with_question" endpoint.

        Args:
            question: User question or prompt
            sso: User's SSO cookie
            schema: Optional response schema (Pydantic model or dict)
            tools: Optional LLM tools (functions or definitions)
            document_id: Optional document reference ID
            document_type: Optional document type
            function_results: Optional results from function execution

        Returns:
            The "answer" field of the processed response when it is a dict
            containing one (JSON-decoded when possible, otherwise as-is);
            in all other cases the processed response itself.
        """
        endpoint = f"/api/{self.app_id}/conversations_with_question"
        request_headers = self._create_cms_headers(sso)
        request_body = self._get_cms_payload(
            question=question,
            schema=schema,
            tools=tools,
            document_id=document_id,
            document_type=document_type,
            function_results=function_results,
        )
        logger.debug(f"CMS Inference Payload: {request_body}")

        raw_response = self.request_handler.post(
            path=endpoint, headers=request_headers, payload=request_body
        )
        logger.debug(f"CMS Inference Response: {raw_response}")

        result = self._process_response(raw_response)

        # Anything that is not a dict carrying an answer is handed back untouched.
        has_answer = isinstance(result, dict) and "answer" in result
        if not has_answer:
            return result

        # The answer may itself be a JSON document encoded as a string.
        answer = result["answer"]
        try:
            decoded_answer = json.loads(answer)
        except (json.JSONDecodeError, TypeError):
            return answer
        return decoded_answer

    async def async_cms_inference(
        self,
        question: str,
        sso: str,
        schema: Optional[Any] = None,
        tools: Optional[List[Union[Dict[str, Any], Callable]]] = None,
        document_id: Optional[str] = None,
        document_type: Optional[Literal["pdf", "jpg", "jpeg", "png"]] = None,
        function_results: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Asynchronously send a question to the CMS
        "conversations_with_question" endpoint.

        Args:
            question: User question or prompt
            sso: User's SSO cookie
            schema: Optional response schema (Pydantic model or dict)
            tools: Optional LLM tools (functions or definitions)
            document_id: Optional document reference ID
            document_type: Optional document type
            function_results: Optional results from function execution

        Returns:
            The "answer" field of the processed response when it is a dict
            containing one (JSON-decoded when possible, otherwise as-is);
            in all other cases the processed response itself.
        """
        endpoint = f"/api/{self.app_id}/conversations_with_question"
        request_headers = self._create_cms_headers(sso)
        request_body = self._get_cms_payload(
            question=question,
            schema=schema,
            tools=tools,
            document_id=document_id,
            document_type=document_type,
            function_results=function_results,
        )
        logger.debug(f"Async CMS Inference Payload: {request_body}")

        raw_response = await self.request_handler.async_post(
            path=endpoint, headers=request_headers, payload=request_body
        )
        logger.debug(f"Async CMS Inference Response: {raw_response}")

        result = await self._process_async_response(raw_response)

        # Anything that is not a dict carrying an answer is handed back untouched.
        has_answer = isinstance(result, dict) and "answer" in result
        if not has_answer:
            return result

        # The answer may itself be a JSON document encoded as a string.
        answer = result["answer"]
        try:
            decoded_answer = json.loads(answer)
        except (json.JSONDecodeError, TypeError):
            return answer
        return decoded_answer

    def stream_cms_inference(
        self,
        conversation_id: str,
        payload: Dict[str, Any],
        question: str,
        sso: str,
        schema: Optional[Any] = None,
        tools: Optional[List[Union[Dict[str, Any], Callable]]] = None,
        function_results: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Post a question to an existing conversation via the CMS
        "questions/stream" endpoint.

        Args:
            conversation_id: Existing conversation ID
            payload: Base payload with conversation details
            question: User question or prompt
            sso: User's SSO cookie
            schema: Optional response schema (Pydantic model or dict)
            tools: Optional LLM tools (functions or definitions)
            function_results: Optional results from function execution

        Returns:
            None when no conversation ID is given. Otherwise the "answer"
            field of the processed response when it is a dict containing
            one (JSON-decoded when possible, otherwise as-is), or the
            processed response itself.
        """
        if not conversation_id:
            logger.error("Cannot stream without a valid conversation ID")
            return None

        endpoint = f"/api/{self.app_id}/conversations/{conversation_id}/questions/stream"
        request_headers = self._create_cms_headers(sso)

        # The caller-supplied payload is extended with the question and options.
        prepared_payload = self._get_cms_payload(
            question=question,
            schema=schema,
            tools=tools,
            payload=payload,
            function_results=function_results,
        )
        logger.debug(f"Stream CMS Inference Payload: {prepared_payload}")

        raw_response = self.request_handler.post(
            path=endpoint, headers=request_headers, payload=prepared_payload
        )
        logger.debug(f"Stream CMS Inference Response: {raw_response}")

        result = self._process_response(raw_response)

        # Anything that is not a dict carrying an answer is handed back untouched.
        has_answer = isinstance(result, dict) and "answer" in result
        if not has_answer:
            return result

        # The answer may itself be a JSON document encoded as a string.
        answer = result["answer"]
        try:
            decoded_answer = json.loads(answer)
        except (json.JSONDecodeError, TypeError):
            return answer
        return decoded_answer

    async def async_stream_cms_inference(
        self,
        conversation_id: str,
        payload: Dict[str, Any],
        question: str,
        sso: str,
        schema: Optional[Any] = None,
        tools: Optional[List[Union[Dict[str, Any], Callable]]] = None,
        function_results: Optional[List[Dict[str, Any]]] = None,
    ) -> Any:
        """
        Asynchronously post a question to an existing CMS conversation.

        Args:
            conversation_id: Existing conversation ID
            payload: Base payload with conversation details
            question: User question or prompt
            sso: User's SSO cookie
            schema: Optional response schema (Pydantic model or dict)
            tools: Optional LLM tools (functions or definitions)
            function_results: Optional results from function execution

        Returns:
            None when no conversation ID is given. Otherwise the "answer"
            field of the processed response when it is a dict containing
            one (JSON-decoded when possible, otherwise as-is), or the
            processed response itself.
        """
        if not conversation_id:
            logger.error("Cannot stream without a valid conversation ID")
            return None

        # Note: unlike the synchronous variant, this targets the /questions
        # endpoint rather than /questions/stream, as seen in the original code.
        endpoint = f"/api/{self.app_id}/conversations/{conversation_id}/questions"
        request_headers = self._create_cms_headers(sso)

        # The caller-supplied payload is extended with the question and options.
        prepared_payload = self._get_cms_payload(
            question=question,
            schema=schema,
            tools=tools,
            payload=payload,
            function_results=function_results,
        )
        logger.debug(f"Async Stream CMS Inference Payload: {prepared_payload}")

        raw_response = await self.request_handler.async_post(
            path=endpoint, headers=request_headers, payload=prepared_payload
        )
        logger.debug(f"Async Stream CMS Inference Response: {raw_response}")

        result = await self._process_async_response(raw_response)

        # Anything that is not a dict carrying an answer is handed back untouched.
        has_answer = isinstance(result, dict) and "answer" in result
        if not has_answer:
            return result

        # The answer may itself be a JSON document encoded as a string.
        answer = result["answer"]
        try:
            decoded_answer = json.loads(answer)
        except (json.JSONDecodeError, TypeError):
            return answer
        return decoded_answer