# In [None]:
import unittest
from unittest.mock import Mock, patch, AsyncMock
import json

from your_package.llm import LLM


class TestLLM(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``LLM`` backed by a mocked CMS client.

    Subclasses ``unittest.IsolatedAsyncioTestCase`` instead of
    ``unittest.TestCase``: the plain ``TestCase`` runner calls ``async def``
    test methods without awaiting them, so their bodies (and assertions)
    never execute and those tests pass vacuously. ``IsolatedAsyncioTestCase``
    awaits coroutine tests on a fresh event loop per test and runs the
    synchronous tests exactly as ``TestCase`` would.
    """

    def setUp(self):
        """Create an ``LLM`` wired to a ``Mock`` CMS client with provider ``"cms"``."""
        self.app_id = "test_app"
        self.env = "test_env"
        self.mock_cms = Mock()

        self.llm = LLM(
            app_id=self.app_id,
            env=self.env,
            cms=self.mock_cms,
            provider_type="cms"
        )

    def test_initialization(self):
        """Constructor arguments are stored on the matching attributes."""
        self.assertEqual(self.llm.app_id, self.app_id)
        self.assertEqual(self.llm.env, self.env)
        self.assertEqual(self.llm.cms, self.mock_cms)
        self.assertEqual(self.llm.provider_type, "cms")

    # Decorators apply bottom-up, so the first mock argument is the
    # ``get_preferences`` patch and the second is the CMS class patch.
    @patch('your_package.llm.ConversationManagementService')
    @patch('your_package.llm.get_preferences')
    def test_init_cms(self, mock_get_preferences, mock_cms_cls):
        """``LLM.init`` builds its CMS client and reads preferences once."""
        mock_get_preferences.return_value = {"temperature": 0.7}
        mock_cms_cls.return_value = "mock_cms_instance"

        def dummy_tool(): pass

        llm = LLM.init(
            app_id=self.app_id,
            env=self.env,
            model_name="test_model",
            temperature=0.7,
            tools=[dummy_tool]
        )

        self.assertEqual(llm.cms, "mock_cms_instance")
        mock_cms_cls.assert_called_once()
        mock_get_preferences.assert_called_once()

    def test_init_dsml_not_implemented(self):
        """``LLM.init`` raises ``NotImplementedError`` for provider ``"dsml"``."""
        with self.assertRaises(NotImplementedError):
            LLM.init(
                app_id=self.app_id,
                provider_type="dsml"
            )

    def test_init_invalid_provider(self):
        """``LLM.init`` raises ``ValueError`` for an unknown provider type."""
        with self.assertRaises(ValueError):
            LLM.init(
                app_id=self.app_id,
                provider_type="invalid"
            )

    def test_get_provider_default(self):
        """Without an override, ``_get_provider`` returns the configured provider."""
        provider = self.llm._get_provider()
        self.assertEqual(provider, "cms")

    def test_get_provider_override(self):
        """An explicit ``provider`` argument is returned when valid."""
        provider = self.llm._get_provider(provider="cms")
        self.assertEqual(provider, "cms")

    def test_get_provider_no_cms(self):
        """``_get_provider`` raises ``ValueError`` for ``"cms"`` with no CMS client."""
        llm = LLM(app_id=self.app_id, env=self.env, provider_type="cms")

        with self.assertRaises(ValueError):
            llm._get_provider()

    def test_get_provider_invalid(self):
        """``_get_provider`` raises ``ValueError`` for an unknown override."""
        with self.assertRaises(ValueError):
            self.llm._get_provider(provider="invalid")

    @patch('your_package.llm.gs_auth')
    def test_get_gssso(self, mock_gs_auth):
        """``_get_gssso`` returns the token produced by ``gs_auth.get_gssso``."""
        mock_gs_auth.get_gssso.return_value = "test_sso"

        result = self.llm._get_gssso()
        self.assertEqual(result, "test_sso")

    @patch('your_package.llm.gs_auth')
    def test_get_gssso_error(self, mock_gs_auth):
        """Any failure in ``gs_auth.get_gssso`` surfaces as ``ValueError``."""
        mock_gs_auth.get_gssso.side_effect = Exception("Test error")

        with self.assertRaises(ValueError):
            self.llm._get_gssso()

    @patch.object(LLM, '_get_gssso')
    def test_invoke_with_tools(self, mock_get_gssso):
        """``invoke(use_tools=True)`` delegates to ``cms.inference_with_tools``."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.inference_with_tools.return_value = "test_response"

        result = self.llm.invoke(
            user_input="test_input",
            use_tools=True
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.inference_with_tools.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    def test_invoke_without_tools(self, mock_get_gssso):
        """``invoke(use_tools=False)`` delegates to ``cms.inference``."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.inference.return_value = "test_response"

        result = self.llm.invoke(
            user_input="test_input",
            use_tools=False
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.inference.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    def test_invoke_with_gssso(self, mock_get_gssso):
        """A caller-supplied ``sso`` means ``_get_gssso`` is never consulted."""
        # ``use_tools`` is not passed: the expected return value assumes
        # invoke() defaults to the ``inference_with_tools`` path.
        self.mock_cms.inference_with_tools.return_value = "test_response"

        result = self.llm.invoke(
            user_input="test_input",
            sso="provided_gssso"
        )

        self.assertEqual(result, "test_response")
        mock_get_gssso.assert_not_called()

    @patch.object(LLM, '_get_gssso')
    def test_invoke_with_document(self, mock_get_gssso):
        """Document id/type are forwarded as keyword arguments to the CMS call."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.inference_with_tools.return_value = "test_response"

        result = self.llm.invoke(
            user_input="test_input",
            document_id="doc123",
            document_type="pdf"
        )

        self.assertEqual(result, "test_response")

        _, kwargs = self.mock_cms.inference_with_tools.call_args
        self.assertEqual(kwargs["document_id"], "doc123")
        self.assertEqual(kwargs["document_type"], "pdf")

    def test_invoke_dsml_not_implemented(self):
        """``invoke(provider="dsml")`` raises ``NotImplementedError``."""
        llm = LLM(
            app_id=self.app_id,
            env=self.env,
            provider_type="dsml",
            dsml=Mock()
        )

        with self.assertRaises(NotImplementedError):
            llm.invoke(
                user_input="test_input",
                provider="dsml"
            )

    @patch.object(LLM, '_get_gssso')
    async def test_ainvoke_with_tools(self, mock_get_gssso):
        """``ainvoke(use_tools=True)`` awaits ``cms.async_inference_with_tools``."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.async_inference_with_tools = AsyncMock(return_value="test_response")

        result = await self.llm.ainvoke(
            user_input="test_input",
            use_tools=True
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.async_inference_with_tools.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    async def test_ainvoke_without_tools(self, mock_get_gssso):
        """``ainvoke(use_tools=False)`` awaits ``cms.async_inference``."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.async_inference = AsyncMock(return_value="test_response")

        result = await self.llm.ainvoke(
            user_input="test_input",
            use_tools=False
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.async_inference.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    def test_stream(self, mock_get_gssso):
        """``stream`` returns whatever ``cms.stream_inference`` returns."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.stream_inference.return_value = "test_response"

        result = self.llm.stream(
            user_input="test_input"
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.stream_inference.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    async def test_astream(self, mock_get_gssso):
        """``astream`` awaits ``cms.async_stream_inference`` and returns its result."""
        # NOTE(review): this awaits astream() as a coroutine. If astream is an
        # async generator, this needs ``async for`` instead — confirm against
        # the implementation now that this test actually executes.
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.async_stream_inference = AsyncMock(return_value="test_response")

        result = await self.llm.astream(
            user_input="test_input"
        )

        self.assertEqual(result, "test_response")
        self.mock_cms.async_stream_inference.assert_called_once()

    @patch.object(LLM, '_get_gssso')
    def test_upload_document(self, mock_get_gssso):
        """``upload_document`` returns the id produced by ``cms.upload_document``."""
        mock_get_gssso.return_value = "test_sso"
        self.mock_cms.upload_document.return_value = "doc123"

        result = self.llm.upload_document(
            file_path="test/path.pdf"
        )

        self.assertEqual(result, "doc123")
        self.mock_cms.upload_document.assert_called_once()

if __name__ == '__main__':
    import sys

    # Run the test suite. Inside a Jupyter/IPython kernel, sys.argv carries the
    # kernel launcher's arguments (e.g. "-f <connection-file>"). unittest.main()
    # would parse these as its own options and test names, and its default
    # exit=True would raise SystemExit in the kernel. In that case, run with a
    # clean argv and without exiting. A plain `python file.py` run keeps the
    # standard behaviour: CLI arguments are honoured and the exit status
    # reflects the results.
    if "ipykernel" in sys.modules:
        unittest.main(argv=[sys.argv[0]], exit=False)
    else:
        unittest.main()