From 58a4044dcfa2c76abd2f019d33a618f04fb9c565 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Sun, 27 Aug 2023 15:16:45 -0400 Subject: [PATCH] fix pg settings type and add tests --- CHANGES.md | 3 +++ tests/test_settings.py | 50 ++++++++++++++++++++++++++++++++++++++++++ tipg/settings.py | 12 +++++----- 3 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 tests/test_settings.py diff --git a/CHANGES.md b/CHANGES.md index 2e362ba0..289898cf 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -12,6 +12,9 @@ Note: Minor version `0.X.0` update might break the API, It's recommended to pin - forward `catalog_dependency` in `OGCFeaturesFactory` and `OGCTilesFactory` when using `Endpoints` factory - allow Factory's prefix with path parameter +- changed `database_url` type in `PostgresSettings` to always be of `pydantic.PostgresDsn` type +- `postgres_port` type in `PostgresSettings` to be of `integer` type +- remove additional `/` prefix for dbname when constructing the database url from individual parameters ### changed diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 00000000..9f20a033 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,50 @@ +"""test tipg settings classes.""" + +import pytest +from pydantic import ValidationError + +from tipg.settings import PostgresSettings + + +def test_pg_settings(monkeypatch): + """test PostgresSettings class.""" + # Makes sure we don't have any pg env set + monkeypatch.delenv("DATABASE_URL", raising=False) + monkeypatch.delenv("POSTGRES_USER", raising=False) + monkeypatch.delenv("POSTGRES_PASS", raising=False) + monkeypatch.delenv("POSTGRES_HOST", raising=False) + monkeypatch.delenv("POSTGRES_PORT", raising=False) + monkeypatch.delenv("POSTGRES_DBNAME", raising=False) + + # Should raises a validation error if no env or parameters is passed + with pytest.raises(ValidationError): + # we use `_env_file=None` to make sure pydantic do not use any `.env` files in local environment + PostgresSettings(_env_file=None) + + settings = PostgresSettings( + postgres_user="user", + postgres_pass="secret", + postgres_host="0.0.0.0", + postgres_port=8888, + postgres_dbname="db", + _env_file=None, + ) + assert str(settings.database_url) == "postgresql://user:secret@0.0.0.0:8888/db" + + # Make sure pydantic will cast the port to integer + settings = PostgresSettings( + postgres_user="user", + postgres_pass="secret", + postgres_host="0.0.0.0", + postgres_port="8888", + postgres_dbname="db", + _env_file=None, + ) + assert str(settings.database_url) == "postgresql://user:secret@0.0.0.0:8888/db" + assert settings.postgres_port == 8888 + + settings = PostgresSettings( + database_url="postgresql://user:secret@0.0.0.0:8888/db", _env_file=None + ) + assert str(settings.database_url) == "postgresql://user:secret@0.0.0.0:8888/db" + assert not settings.postgres_port diff --git a/tipg/settings.py b/tipg/settings.py index d554d520..c8fc88a3 100644 --- a/tipg/settings.py +++ b/tipg/settings.py @@ -1,7 +1,7 @@ """tipg config.""" import pathlib -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from pydantic import ( DirectoryPath, @@ -117,7 +117,7 @@ class PostgresSettings(BaseSettings): postgres_user: Optional[str] = None postgres_pass: Optional[str] = None postgres_host: Optional[str] = None - postgres_port: Optional[str] = None + postgres_port: Optional[int] = None postgres_dbname: Optional[str] = None database_url: Optional[PostgresDsn] = None @@ -131,10 +131,12 @@ class PostgresSettings(BaseSettings): # https://github.com/tiangolo/full-stack-fastapi-postgresql/blob/master/%7B%7Bcookiecutter.project_slug%7D%7D/backend/app/app/core/config.py#L42 @field_validator("database_url", mode="before") - def assemble_db_connection(cls, v: Optional[str], info: FieldValidationInfo) -> Any: + def assemble_db_connection( + cls, v: Optional[str], info: FieldValidationInfo + ) -> PostgresDsn: """Validate db url settings.""" if isinstance(v, str): - return v + return PostgresDsn(v) return PostgresDsn.build( scheme="postgresql", @@ -142,7 +144,7 @@ def assemble_db_connection(cls, v: Optional[str], info: FieldValidationInfo) -> password=info.data.get("postgres_pass"), host=info.data.get("postgres_host", ""), port=info.data.get("postgres_port", 5432), - path=f"/{info.data.get('postgres_dbname') or ''}", + path=info.data.get("postgres_dbname", ""), )