diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index db9698e7..fcb53faf 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -94,6 +94,20 @@ def test_select_model_id_invalid_model(mocker): ) +def test_no_available_models(mocker): + """Test the select_model_id function with an invalid model.""" + mock_client = mocker.Mock() + # empty list of models + mock_client.models.list.return_value = [] + + query_request = QueryRequest(query="What is OpenStack?", model=None, provider=None) + + with pytest.raises(Exception) as exc_info: + select_model_id(mock_client, query_request) + + assert "No LLM model found in available models" in str(exc_info.value) + + def test_validate_attachments_metadata(): """Test the validate_attachments_metadata function.""" attachments = [ @@ -151,7 +165,7 @@ def test_validate_attachments_metadata_invalid_content_type(): ) -def test_retrieve_response(mocker): +def test_retrieve_response_no_available_shields(mocker): """Test the retrieve_response function.""" mock_agent = mocker.Mock() mock_agent.create_turn.return_value.output_message.content = "LLM answer" @@ -172,3 +186,147 @@ def test_retrieve_response(mocker): documents=[], stream=False, ) + + +def test_retrieve_response_one_available_shield(mocker): + """Test the retrieve_response function.""" + + class MockShield: + def __init__(self, identifier): + self.identifier = identifier + + def identifier(self): + return self.identifier + + mock_agent = mocker.Mock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.Mock() + mock_client.shields.list.return_value = [MockShield("shield1")] + + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + + response = retrieve_response(mock_client, model_id, query_request) + + assert response == "LLM answer" + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], + session_id=mocker.ANY, + documents=[], + stream=False, + ) + + +def test_retrieve_response_two_available_shields(mocker): + """Test the retrieve_response function.""" + + class MockShield: + def __init__(self, identifier): + self.identifier = identifier + + def identifier(self): + return self.identifier + + mock_agent = mocker.Mock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.Mock() + mock_client.shields.list.return_value = [ + MockShield("shield1"), + MockShield("shield2"), + ] + + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + + response = retrieve_response(mock_client, model_id, query_request) + + assert response == "LLM answer" + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], + session_id=mocker.ANY, + documents=[], + stream=False, + ) + + +def test_retrieve_response_with_one_attachment(mocker): + """Test the retrieve_response function.""" + mock_agent = mocker.Mock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.Mock() + mock_client.shields.list.return_value = [] + + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="this is attachment", + ), + ] + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) + model_id = "fake_model_id" + + response = retrieve_response(mock_client, model_id, query_request) + + assert response == "LLM answer" + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], + session_id=mocker.ANY, + stream=False, + documents=[ + { + "content": "this is attachment", + "mime_type": "text/plain", + }, + ], + ) + + +def test_retrieve_response_with_two_attachments(mocker): + """Test the retrieve_response function.""" + mock_agent = mocker.Mock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.Mock() + mock_client.shields.list.return_value = [] + + attachments = [ + Attachment( + attachment_type="log", + content_type="text/plain", + content="this is attachment", + ), + Attachment( + attachment_type="configuration", + content_type="application/yaml", + content="kind: Pod\n metadata:\n name: private-reg", + ), + ] + mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + + query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) + model_id = "fake_model_id" + + response = retrieve_response(mock_client, model_id, query_request) + + assert response == "LLM answer" + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], + session_id=mocker.ANY, + stream=False, + documents=[ + { + "content": "this is attachment", + "mime_type": "text/plain", + }, + { + "content": "kind: Pod\n" " metadata:\n" " name: private-reg", + "mime_type": "application/yaml", + }, + ], + )