Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,21 @@ class CORSConfiguration(BaseModel):
allow_origins: list[str] = [
"*"
] # not AnyHttpUrl: we need to support "*" that is not valid URL
allow_credentials: bool = True
allow_credentials: bool = False
allow_methods: list[str] = ["*"]
allow_headers: list[str] = ["*"]

@model_validator(mode="after")
def check_cors_configuration(self) -> Self:
"""Check CORS configuration."""
# credentials are not allowed with wildcard origins per CORS/Fetch spec.
# see https://fastapi.tiangolo.com/tutorial/cors/
if self.allow_credentials and "*" in self.allow_origins:
raise ValueError(
"Invalid CORS configuration: allow_credentials can not be set to true when "
"allow origins contains '*' wildcard."
"Use explicit origins or disable credential."
)
return self


Expand Down
52 changes: 50 additions & 2 deletions tests/unit/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ def test_cors_default_configuration() -> None:
cfg = CORSConfiguration()
assert cfg is not None
assert cfg.allow_origins == ["*"]
assert cfg.allow_credentials is True
assert cfg.allow_credentials is False
assert cfg.allow_methods == ["*"]
assert cfg.allow_headers == ["*"]


def test_cors_custom_configuration() -> None:
def test_cors_custom_configuration_v1() -> None:
"""Test the CORS configuration."""
cfg = CORSConfiguration(
allow_origins=["foo_origin", "bar_origin", "baz_origin"],
Expand All @@ -240,6 +240,54 @@ def test_cors_custom_configuration() -> None:
assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"]


def test_cors_custom_configuration_v2() -> None:
"""Test the CORS configuration."""
cfg = CORSConfiguration(
allow_origins=["foo_origin", "bar_origin", "baz_origin"],
allow_credentials=True,
allow_methods=["foo_method", "bar_method", "baz_method"],
allow_headers=["foo_header", "bar_header", "baz_header"],
)
assert cfg is not None
assert cfg.allow_origins == ["foo_origin", "bar_origin", "baz_origin"]
assert cfg.allow_credentials is True
assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"]
assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"]


def test_cors_custom_configuration_v3() -> None:
"""Test the CORS configuration."""
cfg = CORSConfiguration(
allow_origins=["*"],
allow_credentials=False,
allow_methods=["foo_method", "bar_method", "baz_method"],
allow_headers=["foo_header", "bar_header", "baz_header"],
)
assert cfg is not None
assert cfg.allow_origins == ["*"]
assert cfg.allow_credentials is False
assert cfg.allow_methods == ["foo_method", "bar_method", "baz_method"]
assert cfg.allow_headers == ["foo_header", "bar_header", "baz_header"]


def test_cors_improper_configuration() -> None:
"""Test the CORS configuration."""
expected = (
"Value error, Invalid CORS configuration: "
+ "allow_credentials can not be set to true when allow origins contains '\\*' wildcard."
+ "Use explicit origins or disable credential."
)

with pytest.raises(ValueError, match=expected):
# allow_credentials can not be true when allow_origins contains '*'
CORSConfiguration(
allow_origins=["*"],
allow_credentials=True,
allow_methods=["foo_method", "bar_method", "baz_method"],
allow_headers=["foo_header", "bar_header", "baz_header"],
)


def test_tls_configuration() -> None:
"""Test the TLS configuration."""
cfg = TLSConfiguration(
Expand Down