Skip to content

Commit

Permalink
fix: Ensure password is overrideable
Browse files Browse the repository at this point in the history
Removes the explicit naming of the password parameter for asyncpg, and
moves the setting of the password to the kwargs. Further adds tests to
ensure that the password argument is preserved.
  • Loading branch information
Nicolai Willems committed Oct 8, 2020
1 parent 932c5d1 commit d66c0c1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
5 changes: 4 additions & 1 deletion databases/backends/postgres.py
Expand Up @@ -43,6 +43,7 @@ def _get_dialect(self) -> Dialect:

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options
password = self._database_url.password

kwargs = {}
min_size = url_options.get("min_size")
Expand All @@ -56,6 +57,9 @@ def _get_connection_kwargs(self) -> dict:
if ssl is not None:
kwargs["ssl"] = {"true": True, "false": False}[ssl.lower()]

if password is not None:
kwargs["password"] = password

kwargs.update(self._options)

return kwargs
Expand All @@ -67,7 +71,6 @@ async def connect(self) -> None:
host=self._database_url.hostname,
port=self._database_url.port,
user=self._database_url.username,
password=self._database_url.password,
database=self._database_url.database,
**kwargs,
)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_connection_options.py
Expand Up @@ -43,6 +43,30 @@ def test_postgres_explicit_ssl():
assert kwargs == {"ssl": True}


def test_postgres_no_extra_options():
backend = PostgresBackend("postgres://localhost/database")
kwargs = backend._get_connection_kwargs()
assert kwargs == {}


def test_postgres_password_in_kwargs():
backend = PostgresBackend("postgres://:password@localhost/database")
kwargs = backend._get_connection_kwargs()
assert kwargs == {"password": "password"}


def test_postgres_password_as_callable():
def gen_password():
return "Foo"

backend = PostgresBackend(
"postgres://:password@localhost/database", password=gen_password
)
kwargs = backend._get_connection_kwargs()
assert kwargs == {"password": gen_password}
assert kwargs["password"]() == "Foo"


def test_mysql_pool_size():
backend = MySQLBackend("mysql://localhost/database?min_size=1&max_size=20")
kwargs = backend._get_connection_kwargs()
Expand Down

0 comments on commit d66c0c1

Please sign in to comment.