Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 159 additions & 1 deletion tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ def test_select_model_id_invalid_model(mocker):
)


def test_no_available_models(mocker):
"""Test the select_model_id function with an invalid model."""
mock_client = mocker.Mock()
# empty list of models
mock_client.models.list.return_value = []

query_request = QueryRequest(query="What is OpenStack?", model=None, provider=None)

with pytest.raises(Exception) as exc_info:
select_model_id(mock_client, query_request)

assert "No LLM model found in available models" in str(exc_info.value)


def test_validate_attachments_metadata():
"""Test the validate_attachments_metadata function."""
attachments = [
Expand Down Expand Up @@ -151,7 +165,7 @@ def test_validate_attachments_metadata_invalid_content_type():
)


def test_retrieve_response(mocker):
def test_retrieve_response_no_available_shields(mocker):
"""Test the retrieve_response function."""
mock_agent = mocker.Mock()
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
Expand All @@ -172,3 +186,147 @@ def test_retrieve_response(mocker):
documents=[],
stream=False,
)


def test_retrieve_response_one_available_shield(mocker):
"""Test the retrieve_response function."""

class MockShield:
def __init__(self, identifier):
self.identifier = identifier

def identifier(self):
return self.identifier

mock_agent = mocker.Mock()
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client = mocker.Mock()
mock_client.shields.list.return_value = [MockShield("shield1")]

mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"

response = retrieve_response(mock_client, model_id, query_request)

assert response == "LLM answer"
mock_agent.create_turn.assert_called_once_with(
messages=[UserMessage(content="What is OpenStack?", role="user", context=None)],
session_id=mocker.ANY,
documents=[],
stream=False,
)


def test_retrieve_response_two_available_shields(mocker):
"""Test the retrieve_response function."""

class MockShield:
def __init__(self, identifier):
self.identifier = identifier

def identifier(self):
return self.identifier

mock_agent = mocker.Mock()
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client = mocker.Mock()
mock_client.shields.list.return_value = [
MockShield("shield1"),
MockShield("shield2"),
]

mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)

query_request = QueryRequest(query="What is OpenStack?")
model_id = "fake_model_id"

response = retrieve_response(mock_client, model_id, query_request)

assert response == "LLM answer"
mock_agent.create_turn.assert_called_once_with(
messages=[UserMessage(content="What is OpenStack?", role="user", context=None)],
session_id=mocker.ANY,
documents=[],
stream=False,
)


def test_retrieve_response_with_one_attachment(mocker):
"""Test the retrieve_response function."""
mock_agent = mocker.Mock()
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client = mocker.Mock()
mock_client.shields.list.return_value = []

attachments = [
Attachment(
attachment_type="log",
content_type="text/plain",
content="this is attachment",
),
]
mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)

query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
model_id = "fake_model_id"

response = retrieve_response(mock_client, model_id, query_request)

assert response == "LLM answer"
mock_agent.create_turn.assert_called_once_with(
messages=[UserMessage(content="What is OpenStack?", role="user", context=None)],
session_id=mocker.ANY,
stream=False,
documents=[
{
"content": "this is attachment",
"mime_type": "text/plain",
},
],
)


def test_retrieve_response_with_two_attachments(mocker):
"""Test the retrieve_response function."""
mock_agent = mocker.Mock()
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
mock_client = mocker.Mock()
mock_client.shields.list.return_value = []

attachments = [
Attachment(
attachment_type="log",
content_type="text/plain",
content="this is attachment",
),
Attachment(
attachment_type="configuration",
content_type="application/yaml",
content="kind: Pod\n metadata:\n name: private-reg",
),
]
mocker.patch("app.endpoints.query.Agent", return_value=mock_agent)

query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
model_id = "fake_model_id"

response = retrieve_response(mock_client, model_id, query_request)

assert response == "LLM answer"
mock_agent.create_turn.assert_called_once_with(
messages=[UserMessage(content="What is OpenStack?", role="user", context=None)],
session_id=mocker.ANY,
stream=False,
documents=[
{
"content": "this is attachment",
"mime_type": "text/plain",
},
{
"content": "kind: Pod\n" " metadata:\n" " name: private-reg",
"mime_type": "application/yaml",
},
],
)