diff --git a/docs/config.png b/docs/config.png index e8dbe85f..ae4b19eb 100644 Binary files a/docs/config.png and b/docs/config.png differ diff --git a/docs/config.puml b/docs/config.puml index 96529326..191710c8 100644 --- a/docs/config.puml +++ b/docs/config.puml @@ -19,6 +19,14 @@ class "AuthenticationConfiguration" as src.models.config.AuthenticationConfigura class "AuthorizationConfiguration" as src.models.config.AuthorizationConfiguration { access_rules : Optional[list[AccessRule]] } +class "ByokRag" as src.models.config.ByokRag { + db_path : Annotated + embedding_dimension : Annotated + embedding_model : Annotated + rag_id : Annotated + rag_type : Annotated + vector_db_id : Annotated +} class "CORSConfiguration" as src.models.config.CORSConfiguration { allow_credentials : bool allow_headers : list[str] @@ -29,6 +37,7 @@ class "CORSConfiguration" as src.models.config.CORSConfiguration { class "Configuration" as src.models.config.Configuration { authentication : Optional[AuthenticationConfiguration] authorization : Optional[AuthorizationConfiguration] + byok_rag : Optional[list[ByokRag]] conversation_cache : Optional[ConversationCacheConfiguration] customization : Optional[Customization] database : Optional[DatabaseConfiguration] @@ -155,6 +164,7 @@ class "UserDataCollection" as src.models.config.UserDataCollection { src.models.config.AccessRule --|> src.models.config.ConfigurationBase src.models.config.AuthenticationConfiguration --|> src.models.config.ConfigurationBase src.models.config.AuthorizationConfiguration --|> src.models.config.ConfigurationBase +src.models.config.ByokRag --|> src.models.config.ConfigurationBase src.models.config.CORSConfiguration --|> src.models.config.ConfigurationBase src.models.config.Configuration --|> src.models.config.ConfigurationBase src.models.config.ConversationCacheConfiguration --|> src.models.config.ConfigurationBase diff --git a/docs/config.svg b/docs/config.svg index 58ffc412..5c291608 100644 --- a/docs/config.svg +++ b/docs/config.svg @@ -1,463 +1,484 @@ - + - - - - AccessRule - - actions : list[Action] - role : str - + + + + AccessRule + + actions : list[Action] + role : str + - - - - Action - - name - + + + + Action + + name + - - - - AuthenticationConfiguration - - jwk_config : Optional[JwkConfiguration] - jwk_configuration - k8s_ca_cert_path : Optional[FilePath] - k8s_cluster_api : Optional[AnyHttpUrl] - module : str - skip_tls_verification : bool - - check_authentication_model() -> Self + + + + AuthenticationConfiguration + + jwk_config : Optional[JwkConfiguration] + jwk_configuration + k8s_ca_cert_path : Optional[FilePath] + k8s_cluster_api : Optional[AnyHttpUrl] + module : str + skip_tls_verification : bool + + check_authentication_model() -> Self - - - - AuthorizationConfiguration - - access_rules : Optional[list[AccessRule]] - + + + + AuthorizationConfiguration + + access_rules : Optional[list[AccessRule]] + + + + + + + + ByokRag + + db_path : Annotated + embedding_dimension : Annotated + embedding_model : Annotated + rag_id : Annotated + rag_type : Annotated + vector_db_id : Annotated + - - - - CORSConfiguration - - allow_credentials : bool - allow_headers : list[str] - allow_methods : list[str] - allow_origins : list[str] - - check_cors_configuration() -> Self + + + + CORSConfiguration + + allow_credentials : bool + allow_headers : list[str] + allow_methods : list[str] + allow_origins : list[str] + + check_cors_configuration() -> Self - - - - Configuration - - authentication : Optional[AuthenticationConfiguration] - authorization : Optional[AuthorizationConfiguration] - conversation_cache : Optional[ConversationCacheConfiguration] - customization : Optional[Customization] - database : Optional[DatabaseConfiguration] - inference : Optional[InferenceConfiguration] - llama_stack - mcp_servers : Optional[list[ModelContextProtocolServer]] - name : str - service - user_data_collection - - dump(filename: str) -> None + + + + Configuration + + authentication : Optional[AuthenticationConfiguration] + authorization : Optional[AuthorizationConfiguration] + byok_rag : Optional[list[ByokRag]] + conversation_cache : Optional[ConversationCacheConfiguration] + customization : Optional[Customization] + database : Optional[DatabaseConfiguration] + inference : Optional[InferenceConfiguration] + llama_stack + mcp_servers : Optional[list[ModelContextProtocolServer]] + name : str + service + user_data_collection + + dump(filename: str) -> None - - - - ConfigurationBase - - model_config - + + + + ConfigurationBase + + model_config + - - - - ConversationCacheConfiguration - - memory : Optional[InMemoryCacheConfig] - postgres : Optional[PostgreSQLDatabaseConfiguration] - sqlite : Optional[SQLiteDatabaseConfiguration] - type : Literal['noop', 'memory', 'sqlite', 'postgres'] | None - - check_cache_configuration() -> Self + + + + ConversationCacheConfiguration + + memory : Optional[InMemoryCacheConfig] + postgres : Optional[PostgreSQLDatabaseConfiguration] + sqlite : Optional[SQLiteDatabaseConfiguration] + type : Literal['noop', 'memory', 'sqlite', 'postgres'] | None + + check_cache_configuration() -> Self - - - - CustomProfile - - path : str - prompts : Optional[dict[str, str]] - - get_prompts() -> dict[str, str] + + + + CustomProfile + + path : str + prompts : Optional[dict[str, str]] + + get_prompts() -> dict[str, str] - - - - Customization - - custom_profile : Optional[CustomProfile] - disable_query_system_prompt : bool - profile_path : Optional[str] - system_prompt : Optional[str] - system_prompt_path : Optional[FilePath] - - check_customization_model() -> Self + + + + Customization + + custom_profile : Optional[CustomProfile] + disable_query_system_prompt : bool + profile_path : Optional[str] + system_prompt : Optional[str] + system_prompt_path : Optional[FilePath] + + check_customization_model() -> Self - - - - DatabaseConfiguration - - config - db_type - postgres : Optional[PostgreSQLDatabaseConfiguration] - sqlite : Optional[SQLiteDatabaseConfiguration] - - check_database_configuration() -> Self + + + + DatabaseConfiguration + + config + db_type + postgres : Optional[PostgreSQLDatabaseConfiguration] + sqlite : Optional[SQLiteDatabaseConfiguration] + + check_database_configuration() -> Self - - - - InMemoryCacheConfig - - max_entries : Annotated - + + + + InMemoryCacheConfig + + max_entries : Annotated + - - - - InferenceConfiguration - - default_model : Optional[str] - default_provider : Optional[str] - - check_default_model_and_provider() -> Self + + + + InferenceConfiguration + + default_model : Optional[str] + default_provider : Optional[str] + + check_default_model_and_provider() -> Self - - - - JsonPathOperator - - name - + + + + JsonPathOperator + + name + - - - - JwkConfiguration - - jwt_configuration : Optional[JwtConfiguration] - url : AnyHttpUrl - + + + + JwkConfiguration + + jwt_configuration : Optional[JwtConfiguration] + url : AnyHttpUrl + - - - - JwtConfiguration - - role_rules : Optional[list[JwtRoleRule]] - user_id_claim : str - username_claim : str - + + + + JwtConfiguration + + role_rules : Optional[list[JwtRoleRule]] + user_id_claim : str + username_claim : str + - - - - JwtRoleRule - - compiled_regex - jsonpath : str - negate : bool - operator - roles : list[str] - value : Any - - check_jsonpath() -> Self - check_regex_pattern() -> Self - check_roles() -> Self + + + + JwtRoleRule + + compiled_regex + jsonpath : str + negate : bool + operator + roles : list[str] + value : Any + + check_jsonpath() -> Self + check_regex_pattern() -> Self + check_roles() -> Self - - - - LlamaStackConfiguration - - api_key : Optional[SecretStr] - library_client_config_path : Optional[str] - url : Optional[str] - use_as_library_client : Optional[bool] - - check_llama_stack_model() -> Self + + + + LlamaStackConfiguration + + api_key : Optional[SecretStr] + library_client_config_path : Optional[str] + url : Optional[str] + use_as_library_client : Optional[bool] + + check_llama_stack_model() -> Self - - - - ModelContextProtocolServer - - name : str - provider_id : str - url : str - + + + + ModelContextProtocolServer + + name : str + provider_id : str + url : str + - - - - PostgreSQLDatabaseConfiguration - - ca_cert_path : Optional[FilePath] - db : str - gss_encmode : str - host : str - namespace : Optional[str] - password : SecretStr - port : Annotated - ssl_mode : str - user : str - - check_postgres_configuration() -> Self + + + + PostgreSQLDatabaseConfiguration + + ca_cert_path : Optional[FilePath] + db : str + gss_encmode : str + host : str + namespace : Optional[str] + password : SecretStr + port : Annotated + ssl_mode : str + user : str + + check_postgres_configuration() -> Self - - - - SQLiteDatabaseConfiguration - - db_path : str - + + + + SQLiteDatabaseConfiguration + + db_path : str + - - - - ServiceConfiguration - - access_log : bool - auth_enabled : bool - color_log : bool - cors : Optional[CORSConfiguration] - host : str - port : Annotated - tls_config : Optional[TLSConfiguration] - workers : Annotated - - check_service_configuration() -> Self + + + + ServiceConfiguration + + access_log : bool + auth_enabled : bool + color_log : bool + cors : Optional[CORSConfiguration] + host : str + port : Annotated + tls_config : Optional[TLSConfiguration] + workers : Annotated + + check_service_configuration() -> Self - - - - TLSConfiguration - - tls_certificate_path : Optional[FilePath] - tls_key_password : Optional[FilePath] - tls_key_path : Optional[FilePath] - - check_tls_configuration() -> Self + + + + TLSConfiguration + + tls_certificate_path : Optional[FilePath] + tls_key_password : Optional[FilePath] + tls_key_path : Optional[FilePath] + + check_tls_configuration() -> Self - - - - UserDataCollection - - feedback_enabled : bool - feedback_storage : Optional[str] - transcripts_enabled : bool - transcripts_storage : Optional[str] - - check_storage_location_is_set_when_needed() -> Self + + + + UserDataCollection + + feedback_enabled : bool + feedback_storage : Optional[str] + transcripts_enabled : bool + transcripts_storage : Optional[str] + + check_storage_location_is_set_when_needed() -> Self - - + + - - + + - - + + + + + + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - - custom_profile + + + custom_profile - - - operator + + + operator - - - llama_stack + + + llama_stack - - - sqlite + + + sqlite - - - service + + + service - - - user_data_collection + + + user_data_collection - + diff --git a/src/constants.py b/src/constants.py index 4d4b3237..ead6d259 100644 --- a/src/constants.py +++ b/src/constants.py @@ -127,3 +127,16 @@ CACHE_TYPE_SQLITE = "sqlite" CACHE_TYPE_POSTGRES = "postgres" CACHE_TYPE_NOOP = "noop" + +# BYOK RAG +# Default RAG type for bring-your-own-knowledge RAG configurations, that type +# needs to be supported by Llama Stack +DEFAULT_RAG_TYPE = "inline::faiss" + +# Default sentence transformer model for embedding generation, that type needs +# to be supported by Llama Stack and configured properly in providers and +# models sections +DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2" + +# Default embedding vector dimension for the sentence transformer model +DEFAULT_EMBEDDING_DIMENSION = 768 diff --git a/src/models/config.py b/src/models/config.py index a09e055a..8f68e29e 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -13,6 +13,7 @@ ConfigDict, Field, model_validator, + constr, FilePath, AnyHttpUrl, PositiveInt, @@ -545,6 +546,19 @@ def check_cache_configuration(self) -> Self: return self +class ByokRag(ConfigurationBase): + """BYOK RAG configuration.""" + + rag_id: constr(min_length=1) # type:ignore + rag_type: constr(min_length=1) = constants.DEFAULT_RAG_TYPE # type:ignore + embedding_model: constr(min_length=1) = ( # type:ignore + constants.DEFAULT_EMBEDDING_MODEL + ) + embedding_dimension: PositiveInt = constants.DEFAULT_EMBEDDING_DIMENSION + vector_db_id: constr(min_length=1) # type:ignore + db_path: FilePath + + class Configuration(ConfigurationBase): """Global service configuration.""" @@ -563,6 +577,7 @@ class Configuration(ConfigurationBase): conversation_cache: ConversationCacheConfiguration = Field( default_factory=ConversationCacheConfiguration ) + byok_rag: list[ByokRag] = Field(default_factory=list) def dump(self, filename: str = "configuration.json") -> None: """Dump actual configuration into JSON file.""" diff --git a/tests/configuration/rag.txt b/tests/configuration/rag.txt new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/models/config/test_byok_rag.py b/tests/unit/models/config/test_byok_rag.py new file mode 100644 index 00000000..11db290c --- /dev/null +++ b/tests/unit/models/config/test_byok_rag.py @@ -0,0 +1,130 @@ +"""Unit tests for ByokRag model.""" + +from pathlib import Path + +import pytest + +from pydantic import ValidationError + +from models.config import ByokRag + +from constants import ( + DEFAULT_RAG_TYPE, + DEFAULT_EMBEDDING_MODEL, + DEFAULT_EMBEDDING_DIMENSION, +) + + +def test_byok_rag_configuration_default_values() -> None: + """Test the ByokRag constructor.""" + + byok_rag = ByokRag( + rag_id="rag_id", + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + assert byok_rag is not None + assert byok_rag.rag_id == "rag_id" + assert byok_rag.rag_type == DEFAULT_RAG_TYPE + assert byok_rag.embedding_model == DEFAULT_EMBEDDING_MODEL + assert byok_rag.embedding_dimension == DEFAULT_EMBEDDING_DIMENSION + assert byok_rag.vector_db_id == "vector_db_id" + assert byok_rag.db_path == Path("tests/configuration/rag.txt") + + +def test_byok_rag_configuration_nondefault_values() -> None: + """Test the ByokRag constructor.""" + + byok_rag = ByokRag( + rag_id="rag_id", + rag_type="rag_type", + embedding_model="embedding_model", + embedding_dimension=1024, + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + assert byok_rag is not None + assert byok_rag.rag_id == "rag_id" + assert byok_rag.rag_type == "rag_type" + assert byok_rag.embedding_model == "embedding_model" + assert byok_rag.embedding_dimension == 1024 + assert byok_rag.vector_db_id == "vector_db_id" + assert byok_rag.db_path == Path("tests/configuration/rag.txt") + + +def test_byok_rag_configuration_wrong_dimension() -> None: + """Test the ByokRag constructor.""" + + with pytest.raises(ValidationError, match="should be greater than 0"): + _ = ByokRag( + rag_id="rag_id", + rag_type="rag_type", + embedding_model="embedding_model", + embedding_dimension=-1024, + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + + +def test_byok_rag_configuration_empty_rag_id() -> None: + """Test the ByokRag constructor.""" + + with pytest.raises( + ValidationError, match="String should have at least 1 character" + ): + _ = ByokRag( + rag_id="", + rag_type="rag_type", + embedding_model="embedding_model", + embedding_dimension=1024, + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + + +def test_byok_rag_configuration_empty_rag_type() -> None: + """Test the ByokRag constructor.""" + + with pytest.raises( + ValidationError, match="String should have at least 1 character" + ): + _ = ByokRag( + rag_id="rag_id", + rag_type="", + embedding_model="embedding_model", + embedding_dimension=1024, + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + + +def test_byok_rag_configuration_empty_embedding_model() -> None: + """Test the ByokRag constructor.""" + + with pytest.raises( + ValidationError, match="String should have at least 1 character" + ): + _ = ByokRag( + rag_id="rag_id", + rag_type="rag_type", + embedding_model="", + embedding_dimension=1024, + vector_db_id="vector_db_id", + db_path="tests/configuration/rag.txt", + ) + + +def test_byok_rag_configuration_empty_vector_db_id() -> None: + """Test the ByokRag constructor.""" + + with pytest.raises( + ValidationError, match="String should have at least 1 character" + ): + _ = ByokRag( + rag_id="rag_id", + rag_type="rag_type", + embedding_model="embedding_model", + embedding_dimension=1024, + vector_db_id="", + db_path="tests/configuration/rag.txt", + ) diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py index 03990b01..4718e5b1 100644 --- a/tests/unit/models/config/test_dump_configuration.py +++ b/tests/unit/models/config/test_dump_configuration.py @@ -88,6 +88,7 @@ def test_dump_configuration(tmp_path) -> None: assert "customization" in content assert "inference" in content assert "database" in content + assert "byok_rag" in content # check the whole deserialized JSON file content assert content == { @@ -169,6 +170,7 @@ def test_dump_configuration(tmp_path) -> None: "sqlite": None, "type": None, }, + "byok_rag": [], }