diff --git a/docs/config.png b/docs/config.png index 770f8d5b..0c60626b 100644 Binary files a/docs/config.png and b/docs/config.png differ diff --git a/docs/config.puml b/docs/config.puml index d01f90f7..636d2268 100644 --- a/docs/config.puml +++ b/docs/config.puml @@ -9,6 +9,13 @@ class "AuthenticationConfiguration" as src.models.config.AuthenticationConfigura skip_tls_verification : bool check_authentication_model() -> Self } +class "CORSConfiguration" as src.models.config.CORSConfiguration { + allow_credentials : bool + allow_headers : list[str] + allow_methods : list[str] + allow_origins : list[str] + check_cors_configuration() -> Self +} class "Configuration" as src.models.config.Configuration { authentication customization : Optional[Customization] @@ -78,6 +85,7 @@ class "ServiceConfiguration" as src.models.config.ServiceConfiguration { access_log : bool auth_enabled : bool color_log : bool + cors host : str port : int tls_config @@ -98,6 +106,7 @@ class "UserDataCollection" as src.models.config.UserDataCollection { check_storage_location_is_set_when_needed() -> Self } src.models.config.AuthenticationConfiguration --* src.models.config.Configuration : authentication +src.models.config.CORSConfiguration --* src.models.config.ServiceConfiguration : cors src.models.config.DatabaseConfiguration --* src.models.config.Configuration : database src.models.config.InferenceConfiguration --* src.models.config.Configuration : inference src.models.config.JwtConfiguration --* src.models.config.JwkConfiguration : jwt_configuration diff --git a/src/app/main.py b/src/app/main.py index 48398966..3b9830fd 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -40,12 +40,14 @@ ], ) +cors = configuration.service_configuration.cors + app.add_middleware( CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_origins=cors.allow_origins, + allow_credentials=cors.allow_credentials, + allow_methods=cors.allow_methods, + allow_headers=cors.allow_headers, ) diff --git a/src/models/config.py b/src/models/config.py index 53b58f9f..8acf9579 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -24,6 +24,22 @@ def check_tls_configuration(self) -> Self: return self +class CORSConfiguration(BaseModel): + """CORS configuration.""" + + allow_origins: list[str] = [ + "*" + ] # not AnyHttpUrl: we need to support "*" that is not valid URL + allow_credentials: bool = True + allow_methods: list[str] = ["*"] + allow_headers: list[str] = ["*"] + + @model_validator(mode="after") + def check_cors_configuration(self) -> Self: + """Check CORS configuration.""" + return self + + class SQLiteDatabaseConfiguration(BaseModel): """SQLite database configuration.""" @@ -106,6 +122,7 @@ class ServiceConfiguration(BaseModel): color_log: bool = True access_log: bool = True tls_config: TLSConfiguration = TLSConfiguration() + cors: CORSConfiguration = CORSConfiguration() @model_validator(mode="after") def check_service_configuration(self) -> Self: diff --git a/tests/configuration/lightspeed-stack.yaml b/tests/configuration/lightspeed-stack.yaml index 1cf9565c..d2b4ab1f 100644 --- a/tests/configuration/lightspeed-stack.yaml +++ b/tests/configuration/lightspeed-stack.yaml @@ -6,6 +6,20 @@ service: workers: 1 color_log: true access_log: true + cors: + allow_origins: + - foo_origin + - bar_origin + - baz_origin + allow_credentials: false + allow_methods: + - foo_method + - bar_method + - baz_method + allow_headers: + - foo_header + - bar_header + - baz_header llama_stack: # Uses a remote llama-stack service # The instance would have already been started with a llama-stack-run.yaml file diff --git a/tests/integration/test_configuration.py b/tests/integration/test_configuration.py index 54b1f0cd..22cf7b34 100644 --- a/tests/integration/test_configuration.py +++ b/tests/integration/test_configuration.py @@ -47,6 +47,13 @@ def test_loading_proper_configuration(configuration_filename: str) -> None: assert svc_config.color_log is True assert svc_config.access_log is True + # check 'service.cors' section + cors_config = cfg.service_configuration.cors + assert cors_config.allow_origins == ["foo_origin", "bar_origin", "baz_origin"] + assert cors_config.allow_credentials is False + assert cors_config.allow_methods == ["foo_method", "bar_method", "baz_method"] + assert cors_config.allow_headers == ["foo_header", "bar_header", "baz_header"] + # check 'llama_stack' section ls_config = cfg.llama_stack_configuration assert ls_config.use_as_library_client is False diff --git a/tests/unit/models/test_config.py b/tests/unit/models/test_config.py index a94741a9..f6302ef9 100644 --- a/tests/unit/models/test_config.py +++ b/tests/unit/models/test_config.py @@ -21,6 +21,7 @@ ServiceConfiguration, UserDataCollection, TLSConfiguration, + CORSConfiguration, ModelContextProtocolServer, InferenceConfiguration, ) @@ -214,6 +215,31 @@ def test_user_data_collection_transcripts_disabled() -> None: UserDataCollection(transcripts_enabled=True, transcripts_storage=None) +def test_cors_default_configuration() -> None: + """Test the CORS configuration.""" + cfg = CORSConfiguration() + assert cfg is not None + assert cfg.allow_origins == ["*"] + assert cfg.allow_credentials is True + assert cfg.allow_methods == ["*"] + assert cfg.allow_headers == ["*"] + + +def test_cors_custom_configuration() -> None: + """Test the CORS configuration.""" + cfg = CORSConfiguration( + allow_origins=["foo_origin", "bar_origin", "baz_origin"], + allow_credentials=False, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ) + assert cfg is not None + assert cfg.allow_origins == ["foo_origin", "bar_origin", "baz_origin"] + assert cfg.allow_credentials is False + assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"] + assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"] + + def test_tls_configuration() -> None: """Test the TLS configuration.""" cfg = TLSConfiguration( @@ -437,7 +463,13 @@ def test_dump_configuration(tmp_path) -> None: tls_certificate_path=Path("tests/configuration/server.crt"), tls_key_path=Path("tests/configuration/server.key"), tls_key_password=Path("tests/configuration/password"), - ) + ), + cors=CORSConfiguration( + allow_origins=["foo_origin", "bar_origin", "baz_origin"], + allow_credentials=False, + allow_methods=["foo_method", "bar_method", "baz_method"], + allow_headers=["foo_header", "bar_header", "baz_header"], + ), ), llama_stack=LlamaStackConfiguration( use_as_library_client=True, @@ -488,6 +520,24 @@ def test_dump_configuration(tmp_path) -> None: "tls_key_password": "tests/configuration/password", "tls_key_path": "tests/configuration/server.key", }, + "cors": { + "allow_credentials": False, + "allow_headers": [ + "foo_header", + "bar_header", + "baz_header", + ], + "allow_methods": [ + "foo_method", + "bar_method", + "baz_method", + ], + "allow_origins": [ + "foo_origin", + "bar_origin", + "baz_origin", + ], + }, }, "llama_stack": { "url": None,