In [None]:
import unittest
from unittest import mock
from unittest.mock import Mock, patch, AsyncMock
from http import HTTPStatus
import os

from your_package.utils.conversation_utils import ConversationManagementUtils


class TestConversationManagementUtils(unittest.IsolatedAsyncioTestCase):
    """
    Unit tests for ``ConversationManagementUtils`` backed by a mocked request handler.

    The base class is ``IsolatedAsyncioTestCase`` rather than plain
    ``TestCase``: a plain ``TestCase`` calls ``async def`` test methods
    without awaiting them, so their bodies never execute and they pass
    vacuously. Synchronous test methods run unchanged under this base class.
    """

    def setUp(self):
        self.app_id = "test_app"
        # Stand-in for the HTTP layer; each test configures the
        # ``get`` / ``post`` / ``async_get`` behaviour it needs.
        self.request_handler = Mock()
        self.utils = ConversationManagementUtils(
            app_id=self.app_id,
            request_handler=self.request_handler,
            semaphore_limit=10,
        )

    def test_create_headers(self):
        headers = self.utils._create_headers("test_sso")
        self.assertEqual(headers["Cookie"], "SSO=test_sso")
        self.assertEqual(headers["app-id"], self.app_id)

    # --- get_all_conversations (sync) ---------------------------------------

    def test_get_all_conversations_success(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.json.return_value = [
            {"conversationId": "conv1"},
            {"conversationId": "conv2"},
        ]
        self.request_handler.get.return_value = mock_response

        result = self.utils.get_all_conversations("test_sso")
        self.assertEqual(result, ["conv1", "conv2"])

        self.request_handler.get.assert_called_once()
        kwargs = self.request_handler.get.call_args.kwargs
        self.assertTrue(kwargs["path"].endswith("/conversations"))

    def test_get_all_conversations_failure(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.BAD_REQUEST
        mock_response.text = "Error text"
        self.request_handler.get.return_value = mock_response

        result = self.utils.get_all_conversations("test_sso")
        self.assertEqual(result, [])

    def test_get_all_conversations_exception(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.json.side_effect = Exception("Test error")
        self.request_handler.get.return_value = mock_response

        result = self.utils.get_all_conversations("test_sso")
        self.assertEqual(result, [])

    def test_get_all_conversations_none_response(self):
        self.request_handler.get.return_value = None

        result = self.utils.get_all_conversations("test_sso")
        self.assertEqual(result, [])

    # --- async_get_all_conversations ----------------------------------------
    # These coroutines are awaited by IsolatedAsyncioTestCase; the async
    # response attributes (``json``/``text``) are awaitable mocks.

    async def test_async_get_all_conversations_success(self):
        mock_response = Mock()
        mock_response.status = HTTPStatus.OK
        mock_response.json = AsyncMock(return_value=[
            {"conversationId": "conv1"},
            {"conversationId": "conv2"},
        ])
        self.request_handler.async_get = AsyncMock(return_value=mock_response)

        result = await self.utils.async_get_all_conversations("test_sso")
        self.assertEqual(result, ["conv1", "conv2"])

    async def test_async_get_all_conversations_failure(self):
        mock_response = Mock()
        mock_response.status = HTTPStatus.BAD_REQUEST
        mock_response.text = AsyncMock(return_value="Error text")
        self.request_handler.async_get = AsyncMock(return_value=mock_response)

        result = await self.utils.async_get_all_conversations("test_sso")
        self.assertEqual(result, [])

    async def test_async_get_all_conversations_exception(self):
        mock_response = Mock()
        mock_response.status = HTTPStatus.OK
        mock_response.json = AsyncMock(side_effect=Exception("Test error"))
        self.request_handler.async_get = AsyncMock(return_value=mock_response)

        result = await self.utils.async_get_all_conversations("test_sso")
        self.assertEqual(result, [])

    # --- get_conversation_id ------------------------------------------------

    def test_get_conversation_id_success(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.json.return_value = {
            "conversationId": "conv123",
            "data": "test_data",
        }
        self.request_handler.post.return_value = mock_response

        conv_id, data = self.utils.get_conversation_id("test_sso", {"pref": "value"})
        self.assertEqual(conv_id, "conv123")
        self.assertEqual(data["data"], "test_data")

        kwargs = self.request_handler.post.call_args.kwargs
        self.assertEqual(kwargs["payload"]["preferences"], {"pref": "value"})

    def test_get_conversation_id_failure(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.BAD_REQUEST
        mock_response.text = "Error text"
        self.request_handler.post.return_value = mock_response

        conv_id, data = self.utils.get_conversation_id("test_sso")
        self.assertIsNone(conv_id)
        self.assertEqual(data, {})

    def test_get_conversation_id_exception(self):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.json.side_effect = Exception("Test error")
        self.request_handler.post.return_value = mock_response

        conv_id, data = self.utils.get_conversation_id("test_sso")
        self.assertIsNone(conv_id)
        self.assertEqual(data, {})

    # --- upload -------------------------------------------------------------

    @patch('builtins.open', new_callable=mock.mock_open, read_data=b'test file content')
    def test_upload_success(self, mocked_open):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.OK
        mock_response.json.return_value = {"document_id": "doc123"}
        self.request_handler.post.return_value = mock_response

        result = self.utils.upload("test_sso", "test/path.pdf")
        self.assertEqual(result, "doc123")

        mocked_open.assert_called_once_with("test/path.pdf", "rb")
        kwargs = self.request_handler.post.call_args.kwargs
        self.assertIn("files", kwargs)

    def test_upload_unsupported_file_type(self):
        result = self.utils.upload("test_sso", "test/path.unsupported")
        self.assertIsNone(result)
        self.request_handler.post.assert_not_called()

    @patch('builtins.open', new_callable=mock.mock_open, read_data=b'test file content')
    def test_upload_api_error(self, mocked_open):
        mock_response = Mock()
        mock_response.status_code = HTTPStatus.BAD_REQUEST
        mock_response.text = "Error text"
        self.request_handler.post.return_value = mock_response

        result = self.utils.upload("test_sso", "test/path.pdf")
        self.assertIsNone(result)

    @patch('builtins.open', side_effect=IOError("File not found"))
    def test_upload_file_error(self, mocked_open):
        result = self.utils.upload("test_sso", "nonexistent/path.pdf")
        self.assertIsNone(result)
        self.request_handler.post.assert_not_called()


class AsyncMock(Mock):
    """
    ``Mock`` whose invocation yields a coroutine instead of a value.

    Calling an instance returns a coroutine; only when that coroutine is
    awaited is the call recorded and the configured ``return_value`` /
    ``side_effect`` applied, exactly as a plain ``Mock`` call would.

    NOTE: this module-level class rebinds the name ``AsyncMock``, replacing
    the ``unittest.mock.AsyncMock`` imported at the top of the file for every
    lookup of that name made after this definition runs.
    """

    async def __call__(self, *args, **kwargs):
        outcome = super().__call__(*args, **kwargs)
        return outcome


if __name__ == "__main__":
    import sys

    if "ipykernel" in sys.modules:
        # Inside a Jupyter kernel sys.argv carries the kernel's own flags,
        # which unittest.main() would try to parse as test arguments, and the
        # default exit=True raises SystemExit in the notebook. Run with a
        # clean argv and report results without exiting.
        unittest.main(argv=[sys.argv[0]], exit=False)
    else:
        unittest.main()

ModuleNotFoundError: No module named 'your_package'

In [None]:
"""
Utilities for defining Gemini API functions using Pydantic models.
"""

import json
from typing import Dict, List, Any, Optional, Callable, Type, Union, get_type_hints
from pydantic import BaseModel, Field
from loguru import logger


class FunctionDefinition:
    """
    Function definition using Pydantic models for parameters and return types.

    Pairs a Python callable with the Pydantic model that validates its
    arguments (and optionally a model for its result) and renders them as a
    Gemini API function declaration.
    """

    # JSON-schema scalar ``type`` values and their Gemini schema equivalents.
    _SCALAR_TYPES = {
        "string": "STRING",
        "integer": "INTEGER",
        "number": "NUMBER",
        "boolean": "BOOLEAN",
    }

    def __init__(
        self,
        func: Callable,
        params_model: Type[BaseModel],
        name: Optional[str] = None,
        description: Optional[str] = None,
        result_model: Optional[Type[BaseModel]] = None,
    ):
        """
        Define a function that can be called by Gemini API.

        Args:
            func: The Python function to execute
            params_model: Pydantic model defining parameters
            name: Optional function name (defaults to ``func.__name__``)
            description: Optional function description (defaults to
                ``func.__doc__``, or ``""`` when that is empty)
            result_model: Optional Pydantic model for return type
        """
        self.func = func
        self.params_model = params_model
        self.name = name or func.__name__
        self.description = description or func.__doc__ or ""
        self.result_model = result_model

    def to_gemini_format(self) -> Dict[str, Any]:
        """
        Convert to Gemini API function format.

        Returns:
            Dict with ``name``, ``description`` and ``parameters`` keys, plus
            ``response`` when a ``result_model`` was supplied.
        """
        param_schema = self.params_model.model_json_schema()

        gemini_schema = {
            "name": self.name,
            "description": self.description,
            "parameters": self._convert_schema_to_gemini(param_schema),
        }

        if self.result_model:
            result_schema = self.result_model.model_json_schema()
            gemini_schema["response"] = self._convert_schema_to_gemini(result_schema)

        return gemini_schema

    def _convert_schema_to_gemini(self, schema: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert a model-level JSON schema to a Gemini ``OBJECT`` schema.

        Args:
            schema: JSON schema from Pydantic ``model_json_schema()``; nested
                model definitions are looked up in its ``$defs`` section.

        Returns:
            Schema in Gemini API format (``type``, ``properties`` and, when
            present in the input, ``required``).
        """
        defs = schema.get("$defs", {})
        # Pydantic can emit a root schema that is only a ``$ref`` into
        # ``$defs`` (e.g. for self-referencing models); unwrap it so the
        # properties are not lost. Ordinary schemas are returned unchanged.
        root, _, resolving = self._unwrap(schema, defs, ())

        gemini_schema = {
            "type": "OBJECT",
            "properties": {},
        }
        for prop_name, prop_info in root.get("properties", {}).items():
            gemini_schema["properties"][prop_name] = self._convert_property(
                prop_info, defs, resolving
            )

        if "required" in root:
            gemini_schema["required"] = root["required"]

        return gemini_schema

    def _unwrap(
        self,
        info: Dict[str, Any],
        defs: Dict[str, Any],
        resolving: tuple,
    ) -> tuple:
        """
        Strip JSON-schema indirection from ``info``.

        Follows local ``$ref`` pointers into ``defs``, single-member ``allOf``
        wrappers, and ``anyOf``/``oneOf`` unions with exactly one non-null
        member (the shape Pydantic uses for ``Optional[X]``). Unions with
        several non-null members are left as they are.

        Returns:
            ``(core, nullable, resolving)``: the underlying schema, whether a
            ``null`` branch was dropped, and the chain of ``$defs`` names
            followed so far (used to stop on recursive models).
        """
        nullable = False
        while True:
            if "$ref" in info:
                def_name = info["$ref"].rsplit("/", 1)[-1]
                if def_name in resolving:
                    # Recursive model: stop expanding and emit an opaque object.
                    return {"type": "object"}, nullable, resolving
                if def_name not in defs:
                    # Unresolvable reference: keep it, which maps to STRING.
                    return info, nullable, resolving
                resolving = resolving + (def_name,)
                info = defs[def_name]
            elif len(info.get("allOf", ())) == 1:
                info = info["allOf"][0]
            elif "type" not in info and ("anyOf" in info or "oneOf" in info):
                branches = info.get("anyOf") or info.get("oneOf") or []
                non_null = [b for b in branches if b.get("type") != "null"]
                if len(non_null) != 1:
                    return info, nullable, resolving
                nullable = nullable or len(non_null) < len(branches)
                info = non_null[0]
            else:
                return info, nullable, resolving

    def _convert_property(
        self,
        property_info: Dict[str, Any],
        defs: Optional[Dict[str, Any]] = None,
        _resolving: tuple = (),
    ) -> Dict[str, Any]:
        """
        Convert a property definition to Gemini format.

        Args:
            property_info: Property schema from Pydantic
            defs: The root schema's ``$defs`` mapping, used to resolve
                ``$ref`` pointers to nested models and enums
            _resolving: Internal; ``$defs`` names currently being expanded

        Returns:
            Property in Gemini format. Unknown or unsupported types map to
            ``STRING``; a dropped ``null`` union branch adds
            ``"nullable": True``.
        """
        defs = defs or {}
        core, nullable, resolving = self._unwrap(property_info, defs, _resolving)
        prop_type = core.get("type")
        scalar = self._SCALAR_TYPES.get(prop_type) if isinstance(prop_type, str) else None

        if scalar is not None:
            gemini_prop = {"type": scalar}
        elif prop_type == "array":
            gemini_prop = {
                "type": "ARRAY",
                "items": self._convert_property(core.get("items", {}), defs, resolving),
            }
        elif prop_type == "object":
            gemini_prop = {
                "type": "OBJECT",
                "properties": {
                    name: self._convert_property(info, defs, resolving)
                    for name, info in core.get("properties", {}).items()
                },
            }
            if "required" in core:
                gemini_prop["required"] = core["required"]
        else:
            gemini_prop = {"type": "STRING"}  # Default to string

        # Field-level metadata wins; fall back to the resolved definition's.
        description = property_info.get("description", core.get("description"))
        if description is not None:
            gemini_prop["description"] = description

        enum_values = property_info.get("enum", core.get("enum"))
        if enum_values is not None:
            gemini_prop["enum"] = enum_values

        if nullable:
            gemini_prop["nullable"] = True

        return gemini_prop

    def _passes_model_directly(self) -> bool:
        """Return True when ``func`` should receive the validated model object."""
        import inspect

        try:
            parameters = list(inspect.signature(self.func).parameters.values())
        except (TypeError, ValueError):
            # Signature cannot be introspected: call with keyword arguments.
            return False
        if len(parameters) != 1:
            return False

        only = parameters[0]
        if only.name == "params":
            return True
        try:
            hints = get_type_hints(self.func)
        except Exception:
            # Unresolvable forward references etc.: rely on the name check only.
            return False
        return hints.get(only.name) is self.params_model

    def execute(self, args: Dict[str, Any]) -> Any:
        """
        Validate ``args`` and call the wrapped function with them.

        The arguments are validated by ``params_model`` first. The function
        receives the validated model itself when it takes exactly one
        parameter that is named ``params`` or annotated with
        ``params_model``; otherwise it is called with the model's fields as
        keyword arguments.

        Args:
            args: Arguments for the function

        Returns:
            The function's result, converted to ``result_model`` when one is
            set and the result is not already an instance of it. Any
            exception is logged and returned as ``{"error": <message>}``
            instead of being raised.
        """
        try:
            params = self.params_model(**args)

            if self._passes_model_directly():
                result = self.func(params)
            else:
                result = self.func(**params.model_dump())

            if self.result_model and not isinstance(result, self.result_model):
                # Expects a mapping; other result types fail here and are
                # reported through the error path below.
                return self.result_model(**result)

            return result

        except Exception as e:
            logger.error(f"Error executing function {self.name}: {str(e)}")
            return {"error": str(e)}


class GeminiUtils:
    """
    Utilities for working with Gemini API formats and conversions.
    """

    @staticmethod
    def pydantic_to_schema(schema: Any) -> Dict[str, Any]:
        """
        Convert a Pydantic model to Gemini API schema format.

        Args:
            schema: A Pydantic model (anything exposing ``model_json_schema``);
                any other value is passed through untouched.

        Returns:
            The Gemini ``OBJECT`` schema for the model, or ``schema`` itself
            when it is not a Pydantic model.
        """
        if hasattr(schema, "model_json_schema"):
            # Reuse FunctionDefinition's converter; the dummy callable is
            # never invoked.
            function_def = FunctionDefinition(
                func=lambda: None,
                params_model=schema,
            )
            return function_def._convert_schema_to_gemini(schema.model_json_schema())

        return schema

    @staticmethod
    def extract_function_calls(response: Any) -> List[Dict[str, Any]]:
        """
        Extract function calls from a model response.

        ``response`` may be a dict or its JSON text. Recognised shapes:

        * ``functionCalls`` / ``function_calls`` key: that value is returned
          as-is and nothing else is inspected.
        * ``candidates`` list: ``functionCall`` entries are collected both
          directly on each candidate and from each candidate's
          ``content.parts`` list (the Gemini ``generateContent`` layout).
        * ``toolCalls`` / ``tool_calls`` list: each ``{"function": {...}}``
          entry is normalised to ``{"name", "args"}``; string ``arguments``
          are JSON-decoded, falling back to ``{}`` when undecodable.

        Args:
            response: Model response (dict or JSON string).

        Returns:
            List of function calls; empty when the response is not a dict,
            is not valid JSON, or decodes to something other than an object.
        """
        if isinstance(response, str):
            try:
                response_data = json.loads(response)
            except json.JSONDecodeError:
                # Not JSON, no function calls
                return []
        else:
            response_data = response

        # Valid JSON that is not an object (list, number, null, ...) holds no
        # function calls, and the ``.get`` calls below would fail on it.
        if not isinstance(response_data, dict):
            return []

        # Format 1 / 2: explicit function-call arrays.
        if "functionCalls" in response_data:
            return response_data["functionCalls"]
        if "function_calls" in response_data:
            return response_data["function_calls"]

        function_calls: List[Dict[str, Any]] = []

        # Format 3: functionCall on candidates or inside candidate content parts.
        for candidate in response_data.get("candidates") or []:
            if not isinstance(candidate, dict):
                continue
            if "functionCall" in candidate:
                function_calls.append(candidate["functionCall"])
            content = candidate.get("content")
            parts = content.get("parts") if isinstance(content, dict) else None
            for part in parts or []:
                if isinstance(part, dict) and "functionCall" in part:
                    function_calls.append(part["functionCall"])

        # Format 4: tool calls carrying a nested "function" object.
        tool_calls = response_data.get("toolCalls") or response_data.get("tool_calls") or []
        for tool in tool_calls:
            function_info = tool.get("function") if isinstance(tool, dict) else None
            if not isinstance(function_info, dict):
                continue

            call_args = function_info.get("arguments", {})
            if isinstance(call_args, str):
                try:
                    call_args = json.loads(call_args)
                except json.JSONDecodeError:
                    call_args = {}

            function_calls.append({
                "name": function_info.get("name"),
                "args": call_args,
            })

        return function_calls