Skip to content

Commit

Permalink
Reset counter for failed connections (#385)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergey Morozov <sergey@morozov.top>
  • Loading branch information
taybin and qweryty committed Sep 10, 2021
1 parent e3e7fa0 commit 6fcb168
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
8 changes: 6 additions & 2 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,12 @@ def __init__(self, backend: DatabaseBackend) -> None:
async def __aenter__(self) -> "Connection":
async with self._connection_lock:
self._connection_counter += 1
if self._connection_counter == 1:
await self._connection.acquire()
try:
if self._connection_counter == 1:
await self._connection.acquire()
except Exception as e:
self._connection_counter -= 1
raise e
return self

async def __aexit__(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import os
import re
from unittest.mock import patch, MagicMock

import pytest
import sqlalchemy
Expand All @@ -15,6 +16,11 @@
DATABASE_URLS = [url.strip() for url in os.environ["TEST_DATABASE_URLS"].split(",")]


class AsyncMock(MagicMock):
async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)


class MyEpochType(sqlalchemy.types.TypeDecorator):
impl = sqlalchemy.Integer

Expand Down Expand Up @@ -267,6 +273,30 @@ async def test_ddl_queries(database_url):
await database.execute(query)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_queries_after_error(database_url):
"""
Test that the basic `execute()` works after a previous error.
"""

class DBException(Exception):
pass

async with Database(database_url) as database:
with patch.object(
database.connection()._connection,
"acquire",
new=AsyncMock(side_effect=DBException),
):
with pytest.raises(DBException):
query = notes.select()
await database.fetch_all(query)

query = notes.select()
await database.fetch_all(query)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_results_support_mapping_interface(database_url):
Expand Down

0 comments on commit 6fcb168

Please sign in to comment.