Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import yaml
from models.config import (
Configuration,
Customization,
LLamaStackConfiguration,
UserDataCollection,
ServiceConfiguration,
Expand Down Expand Up @@ -90,5 +91,13 @@ def authentication_configuration(self) -> Optional[AuthenticationConfiguration]:
), "logic error: configuration is not loaded"
return self._configuration.authentication

@property
def customization(self) -> Optional[Customization]:
"""Return customization configuration."""
assert (
self._configuration is not None
), "logic error: configuration is not loaded"
return self._configuration.customization


configuration: AppConfig = AppConfig()
20 changes: 20 additions & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import constants

from utils import checks


class TLSConfiguration(BaseModel):
"""TLS configuration."""
Expand Down Expand Up @@ -123,6 +125,23 @@ def check_authentication_model(self) -> Self:
return self


class Customization(BaseModel):
"""Service customization."""

system_prompt_path: Optional[FilePath] = None
system_prompt: Optional[str] = None

@model_validator(mode="after")
def check_authentication_model(self) -> Self:
"""Load system prompt from file."""
if self.system_prompt_path is not None:
checks.file_check(self.system_prompt_path, "system prompt")
self.system_prompt = checks.get_attribute_from_file(
dict(self), "system_prompt_path"
)
return self


class Configuration(BaseModel):
"""Global service configuration."""

Expand All @@ -134,6 +153,7 @@ class Configuration(BaseModel):
authentication: Optional[AuthenticationConfiguration] = (
AuthenticationConfiguration()
)
customization: Optional[Customization] = None

def dump(self, filename: str = "configuration.json") -> None:
"""Dump actual configuration into JSON file."""
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
"k8s_ca_cert_path": None,
"k8s_cluster_api": None,
},
"customization": {
"system_prompt_path": None,
"system_prompt": None,
},
}

# NOTE(lucasagomes): Configuration must be initialized before importing
Expand Down
1 change: 1 addition & 0 deletions tests/unit/app/endpoints/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_config_endpoint_handler_configuration_loaded(mocker):
"authentication": {
"module": "noop",
},
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/app/endpoints/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_info_endpoint(mocker):
"user_data_collection": {
"feedback_disabled": True,
},
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/app/endpoints/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_models_endpoint_handler_improper_llama_stack_configuration(mocker):
"transcripts_disabled": True,
},
"mcp_servers": [],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -82,6 +83,7 @@ def test_models_endpoint_handler_configuration_loaded(mocker):
"user_data_collection": {
"feedback_disabled": True,
},
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -114,6 +116,7 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
"user_data_collection": {
"feedback_disabled": True,
},
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def setup_configuration():
"transcripts_disabled": True,
},
"mcp_servers": [],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def test_configuration_empty_mcp_servers() -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=[],
customization=None,
)
assert cfg is not None
assert cfg.mcp_servers == []
Expand All @@ -270,6 +271,7 @@ def test_configuration_single_mcp_server() -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=[mcp_server],
customization=None,
)
assert cfg is not None
assert len(cfg.mcp_servers) == 1
Expand All @@ -296,6 +298,7 @@ def test_configuration_multiple_mcp_servers() -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=mcp_servers,
customization=None,
)
assert cfg is not None
assert len(cfg.mcp_servers) == 3
Expand All @@ -317,6 +320,7 @@ def test_dump_configuration(tmp_path) -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=[],
customization=None,
)
assert cfg is not None
dump_file = tmp_path / "test.json"
Expand Down Expand Up @@ -370,6 +374,7 @@ def test_dump_configuration(tmp_path) -> None:
"k8s_ca_cert_path": None,
"k8s_cluster_api": None,
},
"customization": None,
}


Expand All @@ -388,6 +393,7 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=mcp_servers,
customization=None,
)
dump_file = tmp_path / "test.json"
cfg.dump(dump_file)
Expand Down Expand Up @@ -442,6 +448,7 @@ def test_dump_configuration_with_one_mcp_server(tmp_path) -> None:
"k8s_ca_cert_path": None,
"k8s_cluster_api": None,
},
"customization": None,
}


Expand All @@ -462,6 +469,7 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None:
feedback_disabled=True, feedback_storage=None
),
mcp_servers=mcp_servers,
customization=None,
)
dump_file = tmp_path / "test.json"
cfg.dump(dump_file)
Expand Down Expand Up @@ -532,6 +540,7 @@ def test_dump_configuration_with_more_mcp_servers(tmp_path) -> None:
"k8s_ca_cert_path": None,
"k8s_cluster_api": None,
},
"customization": None,
}


Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_init_from_dict() -> None:
"feedback_disabled": True,
},
"mcp_servers": [],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -110,6 +111,7 @@ def test_init_from_dict_with_mcp_servers() -> None:
"url": "https://api.example.com",
},
],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -216,6 +218,7 @@ def test_mcp_servers_property_empty() -> None:
"feedback_disabled": True,
},
"mcp_servers": [],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -251,6 +254,7 @@ def test_mcp_servers_property_with_servers() -> None:
"url": "http://localhost:8080",
},
],
"customization": None,
}
cfg = AppConfig()
cfg.init_from_dict(config_dict)
Expand Down Expand Up @@ -306,3 +310,46 @@ def test_mcp_servers_not_loaded():
AssertionError, match="logic error: configuration is not loaded"
):
cfg.mcp_servers


def test_load_configuration_with_customization(tmpdir) -> None:
"""Test loading configuration from YAML file with customization."""
system_prompt_filename = tmpdir / "system_prompt.txt"
with open(system_prompt_filename, "w") as fout:
fout.write("this is system prompt")

cfg_filename = tmpdir / "config.yaml"
with open(cfg_filename, "w") as fout:
fout.write(
f"""
name: test service
service:
host: localhost
port: 8080
auth_enabled: false
workers: 1
color_log: true
access_log: true
llama_stack:
use_as_library_client: false
url: http://localhost:8321
api_key: test-key
user_data_collection:
feedback_disabled: true
mcp_servers:
- name: filesystem-server
url: http://localhost:3000
- name: git-server
provider_id: custom-git-provider
url: https://git.example.com/mcp
customization:
system_prompt_path: {system_prompt_filename}
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are L#347 to L#355 meant to be here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the configuration is read and system_prompt needs to be initialized from file system_prompt_filename (variable name). The test checks that behaviour.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, my mistake.. when I looked at the whole file it makes perfect sense.

)

cfg = AppConfig()
cfg.load_configuration(cfg_filename)

assert cfg.customization is not None
assert cfg.customization.system_prompt is not None
assert cfg.customization.system_prompt == "this is system prompt"
6 changes: 6 additions & 0 deletions tests/unit/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def test_register_mcp_servers_empty_list(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[],
customization=None,
)
# Call the function
await register_mcp_servers_async(mock_logger, config)
Expand Down Expand Up @@ -80,6 +81,7 @@ async def test_register_mcp_servers_single_server_not_registered(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
customization=None,
)

# Call the function
Expand Down Expand Up @@ -122,6 +124,7 @@ async def test_register_mcp_servers_single_server_already_registered(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
customization=None,
)

# Call the function
Expand Down Expand Up @@ -167,6 +170,7 @@ async def test_register_mcp_servers_multiple_servers_mixed_registration(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=mcp_servers,
customization=None,
)

# Call the function
Expand Down Expand Up @@ -219,6 +223,7 @@ async def test_register_mcp_servers_with_custom_provider(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
customization=None,
)

# Call the function
Expand Down Expand Up @@ -267,6 +272,7 @@ async def test_register_mcp_servers_async_with_library_client(mocker):
),
user_data_collection=UserDataCollection(feedback_disabled=True),
mcp_servers=[mcp_server],
customization=None,
)

# Call the async function
Expand Down