Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,230 +1,24 @@
"""Unit tests for functions defined in src/models/config.py."""
"""Unit tests checking ability to dump configuration."""

import json
from pathlib import Path

import pytest
from pathlib import Path

from pydantic import ValidationError

from models.config import (
Configuration,
ModelContextProtocolServer,
LlamaStackConfiguration,
ServiceConfiguration,
UserDataCollection,
TLSConfiguration,
DatabaseConfiguration,
PostgreSQLDatabaseConfiguration,
CORSConfiguration,
ModelContextProtocolServer,
Configuration,
ServiceConfiguration,
InferenceConfiguration,
PostgreSQLDatabaseConfiguration,
DatabaseConfiguration,
TLSConfiguration,
)


def test_inference_constructor() -> None:
"""
Test the InferenceConfiguration constructor with valid
parameters.
"""
# Test with no default provider or model, as they are optional
inference_config = InferenceConfiguration()
assert inference_config is not None
assert inference_config.default_provider is None
assert inference_config.default_model is None

# Test with default provider and model
inference_config = InferenceConfiguration(
default_provider="default_provider",
default_model="default_model",
)
assert inference_config is not None
assert inference_config.default_provider == "default_provider"
assert inference_config.default_model == "default_model"


def test_inference_default_model_missing() -> None:
"""
Test case where only default provider is set, should fail
"""
with pytest.raises(
ValueError,
match="Default model must be specified when default provider is set",
):
InferenceConfiguration(
default_provider="default_provider",
)


def test_inference_default_provider_missing() -> None:
"""
Test case where only default model is set, should fail
"""
with pytest.raises(
ValueError,
match="Default provider must be specified when default model is set",
):
InferenceConfiguration(
default_model="default_model",
)


def test_user_data_collection_feedback_enabled() -> None:
"""Test the UserDataCollection constructor for feedback."""
# correct configuration
cfg = UserDataCollection(feedback_enabled=False, feedback_storage=None)
assert cfg is not None
assert cfg.feedback_enabled is False
assert cfg.feedback_storage is None


def test_user_data_collection_feedback_disabled() -> None:
"""Test the UserDataCollection constructor for feedback."""
# incorrect configuration
with pytest.raises(
ValueError,
match="feedback_storage is required when feedback is enabled",
):
UserDataCollection(feedback_enabled=True, feedback_storage=None)


def test_user_data_collection_transcripts_enabled() -> None:
"""Test the UserDataCollection constructor for transcripts."""
# correct configuration
cfg = UserDataCollection(transcripts_enabled=False, transcripts_storage=None)
assert cfg is not None


def test_user_data_collection_transcripts_disabled() -> None:
"""Test the UserDataCollection constructor for transcripts."""
# incorrect configuration
with pytest.raises(
ValueError,
match="transcripts_storage is required when transcripts is enabled",
):
UserDataCollection(transcripts_enabled=True, transcripts_storage=None)


def test_model_context_protocol_server_constructor() -> None:
"""Test the ModelContextProtocolServer constructor."""
mcp = ModelContextProtocolServer(name="test-server", url="http://localhost:8080")
assert mcp is not None
assert mcp.name == "test-server"
assert mcp.provider_id == "model-context-protocol"
assert mcp.url == "http://localhost:8080"


def test_model_context_protocol_server_custom_provider() -> None:
"""Test the ModelContextProtocolServer constructor with custom provider."""
mcp = ModelContextProtocolServer(
name="custom-server",
provider_id="custom-provider",
url="https://api.example.com",
)
assert mcp is not None
assert mcp.name == "custom-server"
assert mcp.provider_id == "custom-provider"
assert mcp.url == "https://api.example.com"


def test_model_context_protocol_server_required_fields() -> None:
"""Test that ModelContextProtocolServer requires name and url."""

with pytest.raises(ValidationError):
ModelContextProtocolServer() # pyright: ignore

with pytest.raises(ValidationError):
ModelContextProtocolServer(name="test-server") # pyright: ignore

with pytest.raises(ValidationError):
ModelContextProtocolServer(url="http://localhost:8080") # pyright: ignore


def test_configuration_empty_mcp_servers() -> None:
"""
Test that a Configuration object can be created with an empty
list of MCP servers.

Verifies that the Configuration instance is constructed
successfully and that the mcp_servers attribute is empty.
"""
cfg = Configuration(
name="test_name",
service=ServiceConfiguration(),
llama_stack=LlamaStackConfiguration(
use_as_library_client=True,
library_client_config_path="tests/configuration/run.yaml",
),
user_data_collection=UserDataCollection(
feedback_enabled=False, feedback_storage=None
),
mcp_servers=[],
customization=None,
)
assert cfg is not None
assert not cfg.mcp_servers


def test_configuration_single_mcp_server() -> None:
"""
Test that a Configuration object can be created with a single
MCP server and verifies its properties.
"""
mcp_server = ModelContextProtocolServer(
name="test-server", url="http://localhost:8080"
)
cfg = Configuration(
name="test_name",
service=ServiceConfiguration(),
llama_stack=LlamaStackConfiguration(
use_as_library_client=True,
library_client_config_path="tests/configuration/run.yaml",
),
user_data_collection=UserDataCollection(
feedback_enabled=False, feedback_storage=None
),
mcp_servers=[mcp_server],
customization=None,
)
assert cfg is not None
assert len(cfg.mcp_servers) == 1
assert cfg.mcp_servers[0].name == "test-server"
assert cfg.mcp_servers[0].url == "http://localhost:8080"


def test_configuration_multiple_mcp_servers() -> None:
"""
Verify that the Configuration object correctly handles multiple
ModelContextProtocolServer instances in its mcp_servers list,
including custom provider IDs.
"""
mcp_servers = [
ModelContextProtocolServer(name="server1", url="http://localhost:8080"),
ModelContextProtocolServer(
name="server2", url="http://localhost:8081", provider_id="custom-provider"
),
ModelContextProtocolServer(name="server3", url="https://api.example.com"),
]
cfg = Configuration(
name="test_name",
service=ServiceConfiguration(),
llama_stack=LlamaStackConfiguration(
use_as_library_client=True,
library_client_config_path="tests/configuration/run.yaml",
),
user_data_collection=UserDataCollection(
feedback_enabled=False, feedback_storage=None
),
mcp_servers=mcp_servers,
customization=None,
)
assert cfg is not None
assert len(cfg.mcp_servers) == 3
assert cfg.mcp_servers[0].name == "server1"
assert cfg.mcp_servers[1].name == "server2"
assert cfg.mcp_servers[1].provider_id == "custom-provider"
assert cfg.mcp_servers[2].name == "server3"


def test_dump_configuration(tmp_path) -> None:
"""
Test that the Configuration object can be serialized to a JSON file and
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/models/config/test_inference_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Unit tests for InferenceConfiguration model."""

import pytest

from models.config import InferenceConfiguration


def test_inference_constructor() -> None:
"""
Test the InferenceConfiguration constructor with valid
parameters.
"""
# Test with no default provider or model, as they are optional
inference_config = InferenceConfiguration()
assert inference_config is not None
assert inference_config.default_provider is None
assert inference_config.default_model is None

# Test with default provider and model
inference_config = InferenceConfiguration(
default_provider="default_provider",
default_model="default_model",
)
assert inference_config is not None
assert inference_config.default_provider == "default_provider"
assert inference_config.default_model == "default_model"


def test_inference_default_model_missing() -> None:
"""
Test case where only default provider is set, should fail
"""
with pytest.raises(
ValueError,
match="Default model must be specified when default provider is set",
):
InferenceConfiguration(
default_provider="default_provider",
)


def test_inference_default_provider_missing() -> None:
"""
Test case where only default model is set, should fail
"""
with pytest.raises(
ValueError,
match="Default provider must be specified when default model is set",
):
InferenceConfiguration(
default_model="default_model",
)
Loading