# In [None]:
import unittest
import json
from unittest.mock import Mock, patch, MagicMock
from typing import Dict, List, Any

from your_package.provider.inference_service import InferenceService


class TestInferenceService(unittest.TestCase):
    def setUp(self):
        self.app_id = "test_app"
        self.request_handler = Mock()
        self.preferences = {"temperature": 0.7}
        self.service = InferenceService(
            app_id=self.app_id,
            request_handler=self.request_handler,
            preferences=self.preferences
        )
        
    def test_create_cms_headers(self):
        headers = self.service._create_cms_headers("test_sso")
        self.assertEqual(headers["Cookie"], "SSO=test_sso")
        self.assertEqual(headers["app-id"], self.app_id)
        
    def test_set_cms_preferences(self):
        payload = {"conversation": {}}
        updated = self.service._set_cms_preferences(payload)
        self.assertEqual(updated["conversation"]["preferences"], self.preferences)
        
    @patch('your_package.utils.gemini_utils.GeminiUtils.pydantic_to_schema')
    def test_get_cms_payload_basic(self, mock_pydantic_convert):
        mock_pydantic_convert.return_value = {"type": "OBJECT"}
        
        question = "test question"
        schema = {"custom_schema": True}
        
        payload = self.service._get_cms_payload(question=question, schema=schema)
        
        self.assertEqual(payload["question"]["question"], question)
        self.assertEqual(payload["conversation"]["title"], self.app_id)
        self.assertEqual(
            payload["question"]["responseConfig"]["responseSchema"]["type"], 
            "OBJECT"
        )
        
    @patch('your_package.utils.gemini_utils.GeminiUtils.python_to_function')
    def test_get_cms_payload_with_tools(self, mock_python_to_function):
        mock_function = Mock()
        mock_python_to_function.return_value = {"name": "test_func"}
        
        payload = self.service._get_cms_payload(
            question="test", 
            tools=[mock_function]
        )
        
        self.assertIn("llmTools", payload["question"])
        self.assertEqual(len(payload["question"]["llmTools"]), 1)
        self.assertEqual(payload["question"]["llmTools"][0]["name"], "test_func")
        
    def test_get_cms_payload_with_document(self):
        payload = self.service._get_cms_payload(
            question="test",
            document_id="doc123",
            document_type="pdf"
        )
        
        self.assertIn("questionContext", payload)
        self.assertEqual(payload["questionContext"]["documentReference"], "doc123")
        self.assertEqual(payload["questionContext"]["documentType"], "application/pdf")
        
    def test_get_cms_payload_unsupported_document_type(self):
        payload = self.service._get_cms_payload(
            question="test",
            document_id="doc123",
            document_type="unknown"
        )
        
        self.assertNotIn("questionContext", payload)
        
    def test_get_cms_payload_with_function_results(self):
        function_results = [{"name": "test_func", "result": "test_result"}]
        
        payload = self.service._get_cms_payload(
            question="test question",
            function_results=function_results
        )
        
        self.assertIn("Content:", payload["question"]["question"])
        
    def test_process_response_none(self):
        result = self.service._process_response(None)
        self.assertIsNone(result)
        
    def test_process_response_with_json_method(self):
        mock_response = Mock()
        mock_response.json.return_value = {"answer": "test_answer"}
        
        result = self.service._process_response(mock_response)
        self.assertEqual(result, {"answer": "test_answer"})
        
    def test_process_response_direct_dict(self):
        response = {"answer": "test_answer"}
        result = self.service._process_response(response)
        self.assertEqual(result, response)
        
    def test_process_response_exception(self):
        mock_response = Mock()
        mock_response.json.side_effect = Exception("Test error")
        mock_response.text = "Raw text"
        
        result = self.service._process_response(mock_response)
        self.assertEqual(result, "Raw text")
        
    async def test_process_async_response(self):
        mock_response = Mock()
        mock_response.json = AsyncMock(return_value={"answer": "test_answer"})
        
        result = await self.service._process_async_response(mock_response)
        self.assertEqual(result, {"answer": "test_answer"})
        
    async def test_process_async_response_exception(self):
        mock_response = Mock()
        mock_response.json = AsyncMock(side_effect=Exception("Test error"))
        mock_response.text = AsyncMock(return_value="Raw text")
        
        result = await self.service._process_async_response(mock_response)
        self.assertEqual(result, "Raw text")
        
    def test_cms_inference_basic(self):
        mock_response = Mock()
        mock_response.json.return_value = {"answer": "test_answer"}
        self.request_handler.post.return_value = mock_response
        
        result = self.service.cms_inference(
            question="test question",
            sso="test_sso"
        )
        
        self.assertEqual(result, "test_answer")
        self.request_handler.post.assert_called_once()
        
    def test_cms_inference_json_answer(self):
        json_answer = json.dumps({"result": "test_result"})
        mock_response = Mock()
        mock_response.json.return_value = {"answer": json_answer}
        self.request_handler.post.return_value = mock_response
        
        result = self.service.cms_inference(
            question="test question",
            sso="test_sso"
        )
        
        self.assertEqual(result, {"result": "test_result"})
        
    async def test_async_cms_inference(self):
        mock_response = Mock()
        mock_response.json = AsyncMock(return_value={"answer": "test_answer"})
        self.request_handler.async_post = AsyncMock(return_value=mock_response)
        
        result = await self.service.async_cms_inference(
            question="test question",
            sso="test_sso"
        )
        
        self.assertEqual(result, "test_answer")
        
    def test_stream_cms_inference(self):
        mock_response = Mock()
        mock_response.json.return_value = {"answer": "test_answer"}
        self.request_handler.post.return_value = mock_response
        
        result = self.service.stream_cms_inference(
            conversation_id="conv123",
            payload={},
            question="test question",
            sso="test_sso"
        )
        
        self.assertEqual(result, "test_answer")
        
    async def test_async_stream_cms_inference(self):
        mock_response = Mock()
        mock_response.json = AsyncMock(return_value={"answer": "test_answer"})
        self.request_handler.async_post = AsyncMock(return_value=mock_response)
        
        result = await self.service.async_stream_cms_inference(
            conversation_id="conv123",
            payload={},
            question="test question",
            sso="test_sso"
        )
        
        self.assertEqual(result, "test_answer")


class AsyncMock(MagicMock):
    """A ``MagicMock`` whose call must be awaited.

    Calling an instance returns a coroutine. Awaiting that coroutine performs
    an ordinary ``MagicMock`` call: the call is recorded, ``side_effect`` is
    applied (an exception is raised), and otherwise ``return_value`` is
    produced.
    """

    async def __call__(self, *args, **kwargs):
        # The synchronous MagicMock call runs inside the coroutine body, so
        # nothing is recorded until the caller actually awaits the result.
        return super().__call__(*args, **kwargs)


if __name__ == '__main__':
    # Run this module's tests when the file is executed directly as a script.
    unittest.main()