# %%
"""
Tests for the LLM class.
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock, call
from typing import List, Dict, Any, Optional, Type, Tuple
import asyncio

# Module under test (replace `your_module` with the real package path)
from your_module.llm import LLM
from your_module.provider.cms import ConversationManagementService


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def mock_cms():
    """Return a CMS double whose sync and async entry points yield canned strings."""
    service = Mock(spec=ConversationManagementService)
    # Synchronous endpoints: set return values on the spec'd child mocks.
    service.configure_mock(**{
        "inference.return_value": "Sample response",
        "inference_with_tools.return_value": "Sample response with tools",
        "stream_inference.return_value": "Streamed response",
        "upload_document.return_value": "doc123",
    })
    # Awaitable endpoints are replaced with AsyncMock so callers can await them.
    service.async_inference = AsyncMock(return_value="Async sample response")
    service.async_inference_with_tools = AsyncMock(return_value="Async sample response with tools")
    service.async_stream_inference = AsyncMock(return_value="Async streamed response")
    return service


@pytest.fixture
def mock_dsml():
    """Provide a bare, unconfigured stand-in for the DSML service."""
    return Mock()


@pytest.fixture
def mock_utils():
    """Provide a bare, unconfigured stand-in for the Utils helper."""
    return Mock()


@pytest.fixture
def mock_gemini_utils():
    """Provide a GeminiUtils double whose ``process_tools`` returns one processed tool.

    The second element of the returned tuple maps ``"tool_name"`` to a callable
    that returns ``"tool_result"``.
    """
    processed = ["processed_tool"]
    callables = {"tool_name": lambda: "tool_result"}
    return Mock(**{"process_tools.return_value": (processed, callables)})


@pytest.fixture
def llm_instance(mock_cms):
    """Build an ``LLM`` for app ``test-app`` in ``dev`` backed by the mocked CMS."""
    settings = {"app_id": "test-app", "env": "dev"}
    return LLM(**settings, cms=mock_cms)


# ---------------------------------------------------------------------------
# Initialization Tests
# ---------------------------------------------------------------------------

def test_init_with_default_values():
    """Constructing with only app_id/env stores them, leaves both clients unset and picks CMS."""
    llm = LLM(app_id="test-app", env="dev")

    expected = {"app_id": "test-app", "env": "dev"}
    for attr, value in expected.items():
        assert getattr(llm, attr) == value
    assert llm.cms is None
    assert llm.dsml is None
    assert llm.provider == "cms"  # provider used when none is given


def test_init_with_all_values(mock_cms, mock_dsml):
    """Every constructor argument is stored on the attribute of the same name."""
    kwargs = {
        "app_id": "test-app",
        "env": "prod",
        "cms": mock_cms,
        "dsml": mock_dsml,
        "provider": "dsml",
    }
    llm = LLM(**kwargs)

    for name, value in kwargs.items():
        assert getattr(llm, name) == value


@patch("your_module.llm.Utils")
@patch("your_module.llm.GeminiUtils")
def test_init_sets_utils(mock_gemini_utils_class, mock_utils_class):
    """__init__ builds one Utils (from app_id/env) and one GeminiUtils and keeps both."""
    # Patch decorators apply bottom-up, so the GeminiUtils mock arrives first.
    utils_obj, gemini_obj = Mock(), Mock()
    mock_utils_class.return_value = utils_obj
    mock_gemini_utils_class.return_value = gemini_obj

    llm = LLM(app_id="test-app", env="dev")

    mock_utils_class.assert_called_once_with(app_id="test-app", env="dev")
    mock_gemini_utils_class.assert_called_once()
    assert (llm.utils, llm.gemini_utils) == (utils_obj, gemini_obj)


# ---------------------------------------------------------------------------
# Factory Method Tests
# ---------------------------------------------------------------------------

@patch("your_module.llm.get_preferences")
def test_build_cms_classmethod(mock_get_preferences):
    """Test the build_cms classmethod.

    ``LLM.build_cms`` is already bound to ``LLM``, so passing ``cls_mock`` to it
    positionally does not make it ``cls``. It becomes an extra positional
    argument, and with named parameters the call fails with ``TypeError``
    before any assertion runs. Calling the underlying function through
    ``__func__`` injects ``cls_mock`` as ``cls``.
    """
    mock_get_preferences.return_value = {"temperature": 0.7}

    # Stand-in for ``cls``; its return value plays the constructed CMS instance.
    cms_instance = Mock()
    cls_mock = Mock(return_value=cms_instance)

    result = LLM.build_cms.__func__(
        cls_mock,
        app_id="test-app",
        env="dev",
        model_name="test-model",
        tools=["tool1", "tool2"],
        temperature=0.5,
        reasoning_effort="medium",
        log_level="DEBUG",
        custom_param="value"
    )

    # Preferences are derived from the model/sampling arguments.
    mock_get_preferences.assert_called_once_with(
        inference_model_name="test-model",
        temperature=0.5,
        reasoning_effort="medium",
        custom_param="value"
    )

    # NOTE(review): tools=["tool1", "tool2"] is passed above, yet the constructor
    # is expected to receive tools=None; confirm this matches build_cms's contract.
    cls_mock.assert_called_once_with(
        app_id="test-app",
        env="dev",
        preferences={"temperature": 0.7},
        tools=None,
        log_level="DEBUG",
        custom_param="value"
    )

    assert result == cms_instance


@patch.object(LLM, "build_cms")
def test_init_classmethod(mock_build_cms):
    """Test the init classmethod.

    ``init`` is called through ``__func__`` so that ``cls_mock`` really is
    ``cls``. Passing it to the already-bound ``LLM.init`` would only add a stray
    positional argument and raise ``TypeError``.
    """
    mock_cms_instance = Mock()
    mock_build_cms.return_value = mock_cms_instance

    mock_llm_instance = Mock()
    cls_mock = Mock(return_value=mock_llm_instance)
    # Route ``cls.build_cms`` to the patched mock as well, so the assertions
    # below hold whether init calls ``cls.build_cms`` or ``LLM.build_cms``.
    cls_mock.build_cms = mock_build_cms

    result = LLM.init.__func__(
        cls_mock,
        app_id="test-app",
        env="dev",
        model_name="test-model",
        tools=["tool1", "tool2"],
        provider_type="cms",
        temperature=0.5,
        reasoning_effort="medium",
        log_level="DEBUG",
        custom_param="value"
    )

    # Everything except provider_type is forwarded to build_cms.
    mock_build_cms.assert_called_once_with(
        app_id="test-app",
        env="dev",
        model_name="test-model",
        tools=["tool1", "tool2"],
        temperature=0.5,
        reasoning_effort="medium",
        log_level="DEBUG",
        custom_param="value"
    )

    # NOTE(review): LLM.__init__ is called with ``provider=`` elsewhere in this
    # file; confirm that init really forwards the value as ``provider_type``.
    cls_mock.assert_called_once_with(
        app_id="test-app",
        env="dev",
        cms=mock_cms_instance,
        provider_type="cms"
    )

    assert result == mock_llm_instance


@patch("your_module.llm.GeminiUtils")
@patch.object(LLM, "build_cms")
def test_init_with_tools_processing(mock_build_cms, mock_gemini_utils_class, mock_gemini_utils):
    """Test the init classmethod with tools processing.

    The mocked ``process_tools`` must return the same callables that were
    passed in. Functions compare by identity, so separately created lambdas
    could never equal ``tools[0]`` / ``tools[1][0]`` in the final
    ``tool_callables`` assertion.
    """
    mock_cms_instance = Mock()
    mock_build_cms.return_value = mock_cms_instance

    def tool1():
        return "tool1"

    def tool2():
        return "tool2"

    # One bare callable plus one (callable, "param_model", "response_model") tuple.
    tools = [tool1, (tool2, "param_model", "response_model")]

    mock_gemini_utils_class.return_value = mock_gemini_utils
    mock_gemini_utils.process_tools.return_value = (
        ["processed_tool1", "processed_tool2"],
        {"tool1": tool1, "tool2": tool2},
    )

    mock_llm_instance = Mock()
    cls_mock = Mock(return_value=mock_llm_instance)
    # Capture build_cms whether init reaches it via ``cls`` or via ``LLM``.
    cls_mock.build_cms = mock_build_cms

    # ``__func__`` makes ``cls_mock`` the ``cls`` of the bound classmethod.
    result = LLM.init.__func__(
        cls_mock,
        app_id="test-app",
        env="dev",
        tools=tools
    )

    mock_gemini_utils.process_tools.assert_called_once_with(tools)

    # build_cms receives the processed tool declarations, not the raw tools.
    mock_build_cms.assert_called_once()
    _, kwargs = mock_build_cms.call_args
    assert kwargs["tools"] == ["processed_tool1", "processed_tool2"]

    # The name -> callable mapping is registered on the CMS client.
    assert mock_cms_instance.tool_callables == {
        "tool1": tools[0],
        "tool2": tools[1][0]
    }
    assert result == mock_llm_instance


@patch.object(LLM, "build_cms")
def test_init_dsml_provider_not_implemented(mock_build_cms):
    """Test init raises NotImplementedError with dsml provider.

    ``init`` is called through ``__func__`` so ``cls_mock`` is received as
    ``cls``. Passing it to the bound classmethod would raise ``TypeError``
    before the provider check is reached, and that error escapes
    ``pytest.raises``.
    """
    cls_mock = Mock()

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        LLM.init.__func__(
            cls_mock,
            app_id="test-app",
            env="dev",
            provider_type="dsml"
        )


@patch.object(LLM, "build_cms")
def test_init_invalid_provider(mock_build_cms):
    """Test init raises ValueError with invalid provider.

    ``init`` is called through ``__func__`` so ``cls_mock`` is received as
    ``cls``. Passing it to the bound classmethod would raise ``TypeError``
    instead of the ``ValueError`` under test.
    """
    cls_mock = Mock()

    with pytest.raises(ValueError, match="Unknown provider type: invalid"):
        LLM.init.__func__(
            cls_mock,
            app_id="test-app",
            env="dev",
            provider_type="invalid"
        )


# ---------------------------------------------------------------------------
# Provider Selection Tests
# ---------------------------------------------------------------------------

def test_get_provider_default(llm_instance):
    """With no override, _get_provider falls back to the instance default ("cms")."""
    assert llm_instance._get_provider() == "cms"


def test_get_provider_override(llm_instance):
    """An explicit provider argument wins over the instance default ("cms")."""
    # A DSML client must be present, otherwise selecting "dsml" is rejected.
    llm_instance.dsml = Mock()

    chosen = llm_instance._get_provider("dsml")
    assert chosen == "dsml"


def test_get_provider_cms_not_initialized():
    """Selecting "cms" on an instance without a CMS client raises ValueError."""
    bare = LLM(app_id="test-app", env="dev")
    expectation = pytest.raises(ValueError, match="CMS client not initialized")

    with expectation:
        bare._get_provider("cms")


def test_get_provider_dsml_not_initialized():
    """Selecting "dsml" on an instance without a DSML client raises ValueError."""
    bare = LLM(app_id="test-app", env="dev")
    expectation = pytest.raises(ValueError, match="DSML client not initialized")

    with expectation:
        bare._get_provider("dsml")


def test_get_provider_invalid():
    """An unrecognised provider name raises ValueError naming the bad value."""
    bare = LLM(app_id="test-app", env="dev")
    bogus = "invalid"

    with pytest.raises(ValueError, match="Unknown provider type: invalid"):
        bare._get_provider(bogus)


# ---------------------------------------------------------------------------
# GSSSO Token Tests
# ---------------------------------------------------------------------------

@patch("your_module.llm.gs_auth")
def test_get_gssso_success(mock_gs_auth):
    """_get_gssso returns the token from gs_auth.get_gssso, calling it exactly once."""
    mock_gs_auth.configure_mock(**{"get_gssso.return_value": "test-token"})

    assert LLM._get_gssso() == "test-token"
    mock_gs_auth.get_gssso.assert_called_once()


@patch("your_module.llm.gs_auth")
def test_get_gssso_failure(mock_gs_auth):
    """A gs_auth failure surfaces as ValueError that embeds the original message."""
    auth_failure = Exception("Auth error")
    mock_gs_auth.get_gssso.side_effect = auth_failure

    with pytest.raises(ValueError, match="Failed to get GSSSO token: Auth error"):
        LLM._get_gssso()


# ---------------------------------------------------------------------------
# Invoke Method Tests
# ---------------------------------------------------------------------------

def test_invoke_with_provided_token(llm_instance, mock_cms):
    """An explicit token is forwarded to cms.inference together with schema and document info."""
    doc = {"document_id": "doc123", "document_type": "pdf"}

    answer = llm_instance.invoke(
        user_input="Hello", schema={"type": "object"}, gssso="provided-token", **doc
    )

    assert answer == "Sample response"
    mock_cms.inference.assert_called_once_with(
        question="Hello", gssso="provided-token", schema={"type": "object"}, **doc
    )


@patch.object(LLM, "_get_gssso")
def test_invoke_gets_token_if_not_provided(mock_get_gssso, llm_instance, mock_cms):
    """Without a token, invoke fetches one via _get_gssso and sends None for optional fields."""
    mock_get_gssso.return_value = "auto-token"

    assert llm_instance.invoke(user_input="Hello") == "Sample response"

    mock_get_gssso.assert_called_once()
    unset = dict.fromkeys(("schema", "document_id", "document_type"))
    mock_cms.inference.assert_called_once_with(question="Hello", gssso="auto-token", **unset)


def test_invoke_with_tools(llm_instance, mock_cms):
    """use_tools=True routes the request to cms.inference_with_tools."""
    outcome = llm_instance.invoke(user_input="Hello", gssso="token", use_tools=True)

    assert outcome == "Sample response with tools"
    mock_cms.inference_with_tools.assert_called_once_with(
        question="Hello",
        gssso="token",
        **dict.fromkeys(("schema", "document_id", "document_type")),
    )


@patch.object(LLM, "_get_provider")
def test_invoke_with_provider_override(mock_get_provider, llm_instance, mock_cms):
    """The provider argument goes through _get_provider, whose answer picks the backend."""
    # _get_provider is stubbed to resolve to "cms" even though "dsml" is requested.
    mock_get_provider.return_value = "cms"
    requested = "dsml"

    outcome = llm_instance.invoke(user_input="Hello", gssso="token", provider=requested)

    assert outcome == "Sample response"
    mock_get_provider.assert_called_once_with(requested)
    mock_cms.inference.assert_called_once()


def test_invoke_with_dsml_provider():
    """invoke on a DSML-backed instance raises NotImplementedError."""
    config = {"app_id": "test-app", "env": "dev", "dsml": Mock(), "provider": "dsml"}
    dsml_llm = LLM(**config)

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        dsml_llm.invoke(user_input="Hello", gssso="token")


# ---------------------------------------------------------------------------
# Async Invoke Method Tests
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_ainvoke_with_provided_token(llm_instance, mock_cms):
    """An explicit token is forwarded to cms.async_inference with schema and document info."""
    doc = {"document_id": "doc123", "document_type": "pdf"}

    answer = await llm_instance.ainvoke(
        user_input="Hello", schema={"type": "object"}, gssso="provided-token", **doc
    )

    assert answer == "Async sample response"
    mock_cms.async_inference.assert_called_once_with(
        question="Hello", gssso="provided-token", schema={"type": "object"}, **doc
    )


@pytest.mark.asyncio
@patch.object(LLM, "_get_gssso")
async def test_ainvoke_gets_token_if_not_provided(mock_get_gssso, llm_instance, mock_cms):
    """Without a token, ainvoke fetches one via _get_gssso and sends None for optional fields."""
    mock_get_gssso.return_value = "auto-token"

    answer = await llm_instance.ainvoke(user_input="Hello")
    assert answer == "Async sample response"

    mock_get_gssso.assert_called_once()
    unset = dict.fromkeys(("schema", "document_id", "document_type"))
    mock_cms.async_inference.assert_called_once_with(question="Hello", gssso="auto-token", **unset)


@pytest.mark.asyncio
async def test_ainvoke_with_tools(llm_instance, mock_cms):
    """use_tools=True routes the async request to cms.async_inference_with_tools."""
    outcome = await llm_instance.ainvoke(user_input="Hello", gssso="token", use_tools=True)

    assert outcome == "Async sample response with tools"
    mock_cms.async_inference_with_tools.assert_called_once_with(
        question="Hello",
        gssso="token",
        **dict.fromkeys(("schema", "document_id", "document_type")),
    )


@pytest.mark.asyncio
@patch.object(LLM, "_get_provider")
async def test_ainvoke_with_provider_override(mock_get_provider, llm_instance, mock_cms):
    """The provider argument goes through _get_provider, whose answer picks the backend."""
    # _get_provider is stubbed to resolve to "cms" even though "dsml" is requested.
    mock_get_provider.return_value = "cms"
    requested = "dsml"

    outcome = await llm_instance.ainvoke(user_input="Hello", gssso="token", provider=requested)

    assert outcome == "Async sample response"
    mock_get_provider.assert_called_once_with(requested)
    mock_cms.async_inference.assert_called_once()


@pytest.mark.asyncio
async def test_ainvoke_with_dsml_provider():
    """ainvoke on a DSML-backed instance raises NotImplementedError."""
    config = {"app_id": "test-app", "env": "dev", "dsml": Mock(), "provider": "dsml"}
    dsml_llm = LLM(**config)

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        await dsml_llm.ainvoke(user_input="Hello", gssso="token")


# ---------------------------------------------------------------------------
# Stream Method Tests
# ---------------------------------------------------------------------------

def test_stream(llm_instance, mock_cms):
    """stream forwards question, token and schema to cms.stream_inference."""
    request = dict(user_input="Hello", schema={"type": "object"}, gssso="token", provider="cms")

    assert llm_instance.stream(**request) == "Streamed response"
    mock_cms.stream_inference.assert_called_once_with(
        question="Hello", gssso="token", schema={"type": "object"}
    )


@patch.object(LLM, "_get_gssso")
def test_stream_gets_token_if_not_provided(mock_get_gssso, llm_instance, mock_cms):
    """Without a token, stream fetches one via _get_gssso and passes schema=None."""
    mock_get_gssso.return_value = "auto-token"

    streamed = llm_instance.stream(user_input="Hello")

    assert streamed == "Streamed response"
    mock_get_gssso.assert_called_once()
    mock_cms.stream_inference.assert_called_once_with(
        question="Hello", gssso="auto-token", schema=None
    )


def test_stream_with_dsml_provider():
    """stream on a DSML-backed instance raises NotImplementedError."""
    dsml_client = Mock()
    dsml_llm = LLM(app_id="test-app", env="dev", dsml=dsml_client, provider="dsml")

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        dsml_llm.stream(user_input="Hello", gssso="token")


# ---------------------------------------------------------------------------
# Async Stream Method Tests
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_astream(llm_instance, mock_cms):
    """astream forwards question, token and schema to cms.async_stream_inference."""
    request = dict(user_input="Hello", schema={"type": "object"}, gssso="token", provider="cms")

    streamed = await llm_instance.astream(**request)

    assert streamed == "Async streamed response"
    mock_cms.async_stream_inference.assert_called_once_with(
        question="Hello", gssso="token", schema={"type": "object"}
    )


@pytest.mark.asyncio
@patch.object(LLM, "_get_gssso")
async def test_astream_gets_token_if_not_provided(mock_get_gssso, llm_instance, mock_cms):
    """Without a token, astream fetches one via _get_gssso and passes schema=None."""
    mock_get_gssso.return_value = "auto-token"

    streamed = await llm_instance.astream(user_input="Hello")

    assert streamed == "Async streamed response"
    mock_get_gssso.assert_called_once()
    mock_cms.async_stream_inference.assert_called_once_with(
        question="Hello", gssso="auto-token", schema=None
    )


@pytest.mark.asyncio
async def test_astream_with_dsml_provider():
    """astream on a DSML-backed instance raises NotImplementedError."""
    dsml_client = Mock()
    dsml_llm = LLM(app_id="test-app", env="dev", dsml=dsml_client, provider="dsml")

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        await dsml_llm.astream(user_input="Hello", gssso="token")


# ---------------------------------------------------------------------------
# Batch Processing Tests
# ---------------------------------------------------------------------------

@patch.object(LLM, "invoke")
def test_batch(mock_invoke, llm_instance):
    """batch calls invoke once per input (positionally, shared token) and keeps the order."""
    prompts = ("Input 1", "Input 2", "Input 3")
    mock_invoke.side_effect = [f"Result {n}" for n in (1, 2, 3)]

    collected = llm_instance.batch(user_inputs=list(prompts), gssso="token")

    assert collected == ["Result 1", "Result 2", "Result 3"]
    assert mock_invoke.call_count == len(prompts)
    mock_invoke.assert_has_calls([call(prompt, "token") for prompt in prompts])


@pytest.mark.asyncio
@patch.object(LLM, "ainvoke", new_callable=AsyncMock)
async def test_abatch(mock_ainvoke, llm_instance):
    """Test asynchronous batch processing.

    ``ainvoke`` is replaced by an ``AsyncMock``, so each awaited call yields the
    next plain value from ``side_effect``. Two problems are avoided this way:

    * The mock stores a list ``side_effect`` as an iterator, so the previous
      ``side_effect[0].set_result(...)`` raised ``TypeError``.
    * ``AsyncMock`` returns a ``Future`` taken from ``side_effect`` as-is
      rather than awaiting it, so the batch would have collected Future
      objects instead of strings.
    """
    mock_ainvoke.side_effect = ["Async Result 1", "Async Result 2", "Async Result 3"]

    result = await llm_instance.abatch(
        user_inputs=["Input 1", "Input 2", "Input 3"],
        gssso="token"
    )

    assert result == ["Async Result 1", "Async Result 2", "Async Result 3"]
    # Checking awaits (not just calls) proves each coroutine was actually awaited.
    assert mock_ainvoke.await_count == 3
    mock_ainvoke.assert_has_awaits([
        call("Input 1", "token"),
        call("Input 2", "token"),
        call("Input 3", "token")
    ])


# ---------------------------------------------------------------------------
# Upload Document Tests
# ---------------------------------------------------------------------------

def test_upload_document(llm_instance, mock_cms):
    """upload_document passes token and path to cms.upload_document and returns its id."""
    path = "/path/to/file.pdf"

    doc_id = llm_instance.upload_document(file_path=path, gssso="token", provider="cms")

    assert doc_id == "doc123"
    mock_cms.upload_document.assert_called_once_with(gssso="token", file_path=path)


@patch.object(LLM, "_get_gssso")
def test_upload_document_gets_token_if_not_provided(mock_get_gssso, llm_instance, mock_cms):
    """Without a token, upload_document fetches one via _get_gssso before uploading."""
    mock_get_gssso.return_value = "auto-token"
    path = "/path/to/file.pdf"

    doc_id = llm_instance.upload_document(file_path=path)

    assert doc_id == "doc123"
    mock_get_gssso.assert_called_once()
    mock_cms.upload_document.assert_called_once_with(gssso="auto-token", file_path=path)


def test_upload_document_with_dsml_provider():
    """upload_document on a DSML-backed instance raises NotImplementedError."""
    config = {"app_id": "test-app", "env": "dev", "dsml": Mock(), "provider": "dsml"}
    dsml_llm = LLM(**config)

    with pytest.raises(NotImplementedError, match="DSML provider not yet implemented"):
        dsml_llm.upload_document(file_path="/path/to/file.pdf", gssso="token")


if __name__ == "__main__":
    # Target this file explicitly: a bare pytest.main(["-v"]) collects from the
    # current working directory instead. Propagate pytest's exit code so a
    # failing run does not exit with status 0.
    raise SystemExit(pytest.main(["-v", __file__]))