-
|
Description Is it possible to disable or override middleware in unit tests? Additional context I have written some middleware to obtain a database pool connection and release it when the request is done. I'm using asyncpg directly for this as follows: @app.middleware("http")
async def database_middleware(request: Request, call_next):
conn = None
try:
conn = await database_pool.acquire()
try:
await conn.fetch("SELECT 1")
except ConnectionDoesNotExistError:
conn = await database_pool.acquire()
request.state.db = conn
return await call_next(request)
except (PostgresConnectionError, OSError) as e:
logger.error("Unable to connect to the database: %s", e)
return Response(
"Unable to connect to the database.", status_code=HTTP_500_INTERNAL_SERVER_ERROR
)
except SyntaxOrAccessError as e:
logger.error("Unable to execute query: %s", e)
return Response(
"Unable to execute the required query to obtain data from the database.",
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
)
finally:
if conn:
await database_pool.release(conn)I then have a little def get_db(request: Request):
return request.state.dbI'd like to mock my database in my unit tests so I can verify that I'm building the correct SQL given the appropriate params. I've setup a pytest fixture to mock the database as follows: @pytest.fixture(scope="function")
def db_mock(request):
def fin():
del app.dependency_overrides[get_db]
db_mock = mock.MagicMock()
app.dependency_overrides[get_db] = lambda: db_mock
request.addfinalizer(fin)
return db_mockThis works perfectly when the middleware is disabled. However, when it's enabled, naturally all endpoints fail with a 500 error as the database fails to connect. After looking a little deeper into starlette codebase and the way middleware works, I couldn't seem to find an elegant way to disable or override middleware in my unit tests. Any help is greatly appreciated. Also if you see anything wrong with the approach I'm taking, please do let me know. Huge thanks! |
Beta Was this translation helpful? Give feedback.
Replies: 23 comments
-
|
I'd try to use asynctest CoroutineMock instead of MagicMock |
Beta Was this translation helpful? Give feedback.
-
Oh interesting idea, but yeah, pytest injection doesn't use async at this stage. I'm injecting this fixture as follows: def test_events(client, db_mock, current_user_mock):
db_mock.fetch.return_value = Future()
db_mock.fetch.return_value.set_result([])
response = client.post("/v1/events/", json={"id": ["bla"]})
assert response.status_code == 200
assert response.json() == []
assert db_mock.fetch.call_args == mock.call(
"SELECT * "
"FROM events"
"WHERE id = ANY ($1) "
"ORDER BY event_timestamp DESC "
"LIMIT $2",
["bla"], 20
)Of course, I'm totally open to better ways of accomplishing this! 😄 I'm definitely a bit new to async in Python and am coming from Falcon (which is not async at all). |
Beta Was this translation helpful? Give feedback.
-
|
Actually @euri10, you're right, I can use the @pytest.fixture(scope="function")
def db_mock(request):
def fin():
del app.dependency_overrides[get_db]
db_mock = asynctest.Mock()
app.dependency_overrides[get_db] = lambda: db_mock
request.addfinalizer(fin)
return db_mock
def test_events(client, db_mock, current_user_mock):
db_mock.fetch = asynctest.CoroutineMock(return_value=[])
response = client.post("/v1/events/", json={"id": ["bla"]})
assert response.status_code == 200
assert response.json() == []
assert db_mock.fetch.call_args == mock.call(
"SELECT * "
"FROM events "
"WHERE id = ANY ($1) "
"ORDER BY event_timestamp DESC "
"LIMIT $2",
["bla"], 20
)The reason this works is that the endpoint itself uses |
Beta Was this translation helpful? Give feedback.
-
|
you can even use pytest.asyncio and have |
Beta Was this translation helpful? Give feedback.
-
|
So far, the only idea I can come up with is writing a def create_app(environment: str = os.getenv("ENVIRONMENT")):
app = FastAPI()
# Don't register events and middleware relating to the database while testing.
# Event handlers don't seem to run during testing anyway but we'll do this to be sure.
if environment != "testing":
app.add_event_handler("startup", startup)
app.add_event_handler("shutdown", shutdown)
app.add_middleware(BaseHTTPMiddleware, database_middleware)
api_router = APIRouter()
api_router.include_router(events.router, prefix="/events", tags=["events"])
# ...
app.include_router(api_router, prefix="/v1")More suggestions or ideas welcome though 😄 |
Beta Was this translation helpful? Give feedback.
-
|
@fgimian that’s exactly what I do |
Beta Was this translation helpful? Give feedback.
-
|
just to understand your setup a little bit better and eventually come up with a "better" mock @fgimian If I understand correctly from your post, the goal is to be sure a db_mock is passed the same sql query the enpoint once triggered is supposed to generate, this without having to pass through the db_middleware, is that correct ? I think it's a very interesting topic, I'm not using directly asyncpg and deal with db connection with startup and shutdown lifespan events using the |
Beta Was this translation helpful? Give feedback.
-
|
I'll definitely try out the databases package! I honestly thought it was a local import when I saw the code and clearly didn't read that part of the guide well enough. My current approach is not perfect, but goes something like this: # main.py
app = FastAPI()
logger = logging.getLogger(__name__)
database_pool = None
@app.on_event("startup")
async def startup():
# TODO: Is there a better way to create the pool without using a global?
global database_pool # pylint: disable=global-statement
database_pool = await asyncpg.create_pool(
host=config.DATABASE_HOST,
port=config.DATABASE_PORT,
user=config.DATABASE_USERNAME,
password=config.DATABASE_PASSWORD,
database=config.DATABASE_DBNAME,
min_size=0, # don't create any connections upon start-up
max_size=config.DATABASE_MAX_CONNECTIONS,
max_inactive_connection_lifetime=0, # never close connections after they're established
)
@app.on_event("shutdown")
async def shutdown():
await database_pool.close()
@app.middleware("http")
async def database_middleware(request: Request, call_next):
# TODO: Improve the mechanism used to determine whether we are in a unit test or not
if request.headers["user-agent"] == "testclient":
return await call_next(request)
conn = None
try:
conn = await database_pool.acquire()
try:
await conn.fetch("SELECT 1")
except ConnectionDoesNotExistError:
conn = await database_pool.acquire()
request.state.db = conn
return await call_next(request)
except (PostgresConnectionError, OSError) as e:
logger.error("Unable to connect to the database: %s", e)
return Response(
"Unable to connect to the database.", status_code=HTTP_500_INTERNAL_SERVER_ERROR
)
except SyntaxOrAccessError as e:
logger.error("Unable to execute query: %s", e)
return Response(
"Unable to execute the required query to obtain data from the database.",
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
)
finally:
if conn:
await database_pool.release(conn)
# database.py
def get_db(request: Request):
return request.state.dbHope this helps. As per the TODOs, there are still some things to improve if possible. The important thing with the database is that I don't want a connection to be made when the app starts up because that would mean the app could fail to start if the database was down at that moment. My preference is that the initial connections are established by the first few web requests and then are kept open forever. I'm typically using a pool size of around 20 per process. Interestingly, it seems that uvicorn doesn't allow starting an app from a Cheers |
Beta Was this translation helpful? Give feedback.
-
|
@fgimian There are many ways to make sure your app waits until the database/other necessary resources are ready. Some examples:
Click to expand (adapts your startup function to use tenacity)import logging
from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
max_tries = 60 * 5 # 5 minutes
wait_seconds = 3
@retry(
stop=stop_after_attempt(max_tries),
wait=wait_fixed(wait_seconds),
before=before_log(logger, logging.INFO),
after=after_log(logger, logging.WARN),
)
async def wait_for_db(database_pool) -> None:
conn = await database_pool.acquire()
await conn.fetch("SELECT 1")
@app.on_event("startup")
async def startup():
# TODO: Is there a better way to create the pool without using a global?
global database_pool # pylint: disable=global-statement
database_pool = await asyncpg.create_pool(
host=config.DATABASE_HOST,
port=config.DATABASE_PORT,
user=config.DATABASE_USERNAME,
password=config.DATABASE_PASSWORD,
database=config.DATABASE_DBNAME,
min_size=0, # don't create any connections upon start-up
max_size=config.DATABASE_MAX_CONNECTIONS,
max_inactive_connection_lifetime=0, # never close connections after they're established
)
await wait_for_db(database_pool)This way you wouldn't need to waste a db round trip at the start of every request. Also, +1 for databases. |
Beta Was this translation helpful? Give feedback.
-
|
@dmontagu Thank you so much, this is an awesome idea. The use of The reason it is needed is that database connections in a pool can go stale after initially established (due to network interruptions or db restarts). So you can't be sure that We had this issue in production too and the idea was to add a small |
Beta Was this translation helpful? Give feedback.
-
|
@fgimian Those are some good points; thanks for sharing the sqlalchemy docs page, I wasn't aware of that and it was good reading. In light of this consideration, I think I personally would be inclined to make use of an approach that just periodically polls the connection pool for stale connections, rather than going full "pessimistic", but I guess for most applications the overhead would probably be insignificant. |
Beta Was this translation helpful? Give feedback.
-
|
I'd like to share my solution to this problem. Firstly, I was never fond of using middleware of opening database connections. This implies that every single endpoint will open a database connection whether it is needed or not (e.g. in our case, the auth endpoint doesn't use the database as it talks to LDAP). I looked into the The solution goes something like this:
So here's my database.py: import logging
import asyncpg
from asyncpg import ConnectionDoesNotExistError, PostgresConnectionError, SyntaxOrAccessError
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from .settings import get_settings
pool = None
logger = logging.getLogger(__name__)
async def create_pool() -> None:
global pool # pylint: disable=global-statement
settings = get_settings()
pool = await asyncpg.create_pool(
host=settings.database.host,
port=settings.database.port,
user=settings.database.username,
password=settings.database.password,
database=settings.database.db_name,
min_size=0, # don't create any connections upon start-up
max_size=settings.database.max_connections,
max_inactive_connection_lifetime=0, # never close connections after they're established
)
async def close_pool() -> None:
await pool.close()
async def get_db(request: Request) -> asyncpg.connection.Connection:
"""Obtain a database connection from the pool."""
try:
conn = await pool.acquire()
# Test that the connection is still active by running a trivial query
# (https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic)
try:
await conn.execute("SELECT 1")
except ConnectionDoesNotExistError:
conn = await pool.acquire()
request.state.db = conn
return conn
except (PostgresConnectionError, OSError) as e:
logger.error("Unable to connect to the database: %s", e)
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR, detail="Unable to connect to the database."
)
async def middleware(request: Request, call_next):
"""Ensures that any open database connection is closed after each request."""
try:
return await call_next(request)
except SyntaxOrAccessError as e:
logger.error("Unable to execute query: %s", e)
return JSONResponse(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
content={
"detail": "Unable to execute the required query to obtain data from the database."
},
)
finally:
if hasattr(request.state, "db"):
await pool.release(request.state.db)I'm still not totally fond of using a global, but it seems to be the only way at present. I attempted to wrap this in a class and call an async method when injecting the dependency, but FastAPI didn't seem to know how to deal with that (it didn't execute the injection as a coroutine when it was a class method). The great part here is that now, a connection is not established while running unit tests as I have a mock version of @pytest.fixture
def client():
return TestClient(app)
@pytest.fixture(scope="function")
def db_mock(request):
def fin():
del app.dependency_overrides[get_db]
db = asynctest.Mock()
app.dependency_overrides[get_db] = lambda: db
request.addfinalizer(fin)
return dbI will now be attempting to write unit tests for my database module. 😄 Of course, I'm totally open to critique on my solution and welcome more ideas. Cheers |
Beta Was this translation helpful? Give feedback.
-
|
I played a little bit with it and like it, no need for a db anymore !
The only thing I find, that can potentially become very boring down the
road is that, but maybe it is me who wrote it weirdly, you will have to
write in each and every test something like mocked_db.fetch =
CoroutineMock(return_value=[]) or execute, etc depending on the method(s)
used in the endpoint.
https://gitlab.com/euri10/rapidfastapitest/blob/master/582_mock_middleware.py#L105
So there might exist a nicer fixture that provides those. Couldn't find a
nice way yet
Le sam. 5 oct. 2019 à 8:56 AM, Fotis Gimian <notifications@github.com> a
écrit :
… I'd like to share my solution to this problem. Firstly, I was never fond
of using middleware of opening database connections. This implies that
every single endpoint will open a database connection whether it is needed
or not (e.g. in our case, the auth endpoint doesn't use the database as it
talks to LDAP).
I looked into the databases library which does look great, but offers no
advantages I really need over asyncpg. So the following solution still
uses asyncpg.
The solution goes something like this:
- Inject a dependency when the database is needed in an endpoint
- This injected function will obtain a connection from the database
pool and hand it to the endpoint while also saving the conection in
request.state
- A small middleware function will check whether or not a connection
has been set in request.state after an endpoint completes processing
and will close the connection if needed
So here's my *database.py*:
import logging
import asyncpg
from asyncpg import ConnectionDoesNotExistError, PostgresConnectionError, SyntaxOrAccessError
from fastapi import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from . import config
pool = None
logger = logging.getLogger(__name__)
async def create_pool() -> None:
global pool
pool = await asyncpg.create_pool(
host=config.DATABASE_HOST,
port=config.DATABASE_PORT,
user=config.DATABASE_USERNAME,
password=config.DATABASE_PASSWORD,
database=config.DATABASE_DBNAME,
min_size=0, # don't create any connections upon start-up
max_size=config.DATABASE_MAX_CONNECTIONS,
max_inactive_connection_lifetime=0, # never close connections after they're established
)
async def close_pool() -> None:
await pool.close()
async def get_db(request: Request) -> asyncpg.connection.Connection:
"""Obtain a database connection from the pool."""
try:
conn = await pool.acquire()
# Test that the connection is still active by running a trivial query
# (https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic)
try:
await conn.fetch("SELECT 1")
except ConnectionDoesNotExistError:
conn = await pool.acquire()
request.state.db = conn
return conn
except (PostgresConnectionError, OSError) as e:
logger.error("Unable to connect to the database: %s", e)
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unable to connect to the database.",
)
except SyntaxOrAccessError as e:
logger.error("Unable to execute query: %s", e)
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unable to execute the required query to obtain data from the database.",
)
async def middleware(request: Request, call_next):
"""Ensures that any open database connection is closed after each request."""
try:
return await call_next(request)
finally:
if hasattr(request.state, "db"):
await pool.release(request.state.db)
I'm still not totally fond of using a global, but it seems to be the only
way at present. I attempted to wrap this in a class and call an async
method when injecting the dependency, but FastAPI didn't seem to know how
to deal with that (it didn't execute the injection as a coroutine when it
was a class method).
The great part here is that now, a connection is not established while
running unit tests as I have a mock version of get_db:
@pytest.fixture
def client():
return TestClient(app)
@pytest.fixture(scope="function")
def db_mock(request):
def fin():
del app.dependency_overrides[get_db]
db = asynctest.Mock()
app.dependency_overrides[get_db] = lambda: db
request.addfinalizer(fin)
return db
I will now be attempting to write unit tests for my database module. 😄
Of course, I'm totally open to critique on my solution and welcome more
ideas.
Cheers
Fotis
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#582>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AAINSPTTBO4GBPI67LIHMM3QNA3DJANCNFSM4I4IN2NA>
.
|
Beta Was this translation helpful? Give feedback.
-
|
Hey @euri10, yep that's exactly what I'm doing. def test_bla(client, db_mock, current_user_mock):
db_mock.fetch = asynctest.CoroutineMock(
return_value=[
Record([("abc", "def")]),
...
]The only little caviet was that asyncpg's Record class can't be instantiated. So I had to create a replica myself which looks like this: from collections import OrderedDict
class Record(OrderedDict):
"""
Provides a very similar Record class to that returned by asyncpg. A custom implementation
is needed as it is currently impossible to create asyncpg Record objects from Python code.
"""
def __getitem__(self, key_or_index):
if isinstance(key_or_index, int):
return list(self.values())[key_or_index]
return super().__getitem__(key_or_index)
def __repr__(self):
return "<{class_name} {items}>".format(
class_name=self.__class__.__name__,
items=" ".join(f"{k}={v!r}" for k, v in self.items()),
)This has worked perfectly for all my unit tests, I now have 100% coverage on all my endpoints. The challenge is writing unit tests for the |
Beta Was this translation helpful? Give feedback.
-
|
I figured you can declare execute, fetch, fetchrow etc as mock coroutines
in the db mock fixture, but the expected record is rd is still missing,
there might be a way
I'll play a little more with the record mock you wrote, thanks!
Le sam. 5 oct. 2019 à 1:33 PM, Fotis Gimian <notifications@github.com> a
écrit :
… Hey @euri10 <https://github.com/euri10>, yep that's exactly what I'm
doing.
def test_bla(client, db_mock, current_user_mock):
db_mock.fetch = asynctest.CoroutineMock(
return_value=[
Record([("abc", "def")]),
...
]
The only little caviet was that asyncpg's Record class can't be
instantiated. So I had to create a replica myself which looks like this:
from collections import OrderedDict
class Record(OrderedDict):
""" Provides a very similar Record class to that returned by asyncpg. A custom implementation is needed as it is currently impossible to create asyncpg Record objects from Python code. """
def __getitem__(self, key_or_index):
if isinstance(key_or_index, int):
return list(self.values())[key_or_index]
return super().__getitem__(key_or_index)
def __repr__(self):
return "<{class_name} {items}>".format(
class_name=self.__class__.__name__,
items=" ".join(f"{k}={v!r}" for k, v in self.items()),
)
This has worked perfectly for all my unit tests, I now have 100% coverage
on all my endpoints. The challenge is writing unit tests for the
database.py file.
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#582>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AAINSPXFPIEXVFJ7PP7XXVDQNB3P5ANCNFSM4I4IN2NA>
.
|
Beta Was this translation helpful? Give feedback.
-
|
No worries at all @euri10, really appreciate this discussion and how helpful you and @dmontagu have been here. It's so nice to see such a friendly community around FastAPI! 😄 I'll likely write a blog post or two after I finish this API implementation with various learnings to share. Cheers |
Beta Was this translation helpful? Give feedback.
-
|
In case anyone is interested in using my implementation, here are the related unit tests: tests/utils.py from unittest.mock import MagicMock
class AsyncMagicMock(MagicMock):
"""Implements a MagicMock class which return async methods."""
async def __call__(self, *args, **kwargs): # pylint: disable=useless-super-delegation
return super().__call__(*args, **kwargs)tests/conftest.py import pytest
from asyncpg.pool import PoolConnectionProxy
from fastapi import Depends, FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from myapp import database
from myapp.database import get_db
@pytest.fixture
def app():
app = FastAPI()
app.add_event_handler("startup", database.create_pool)
app.add_event_handler("shutdown", database.close_pool)
app.add_middleware(BaseHTTPMiddleware, dispatch=database.middleware)
@app.get("/")
async def root(db: PoolConnectionProxy = Depends(get_db)):
return await db.fetch("SELECT * FROM data")
return apptests/test_database.py from unittest import mock
from asyncpg import ConnectionDoesNotExistError, SyntaxOrAccessError
from starlette.testclient import TestClient
from .utils import AsyncMagicMock
@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_successful_query(create_pool_mock, app):
db_mock = create_pool_mock.return_value.acquire.return_value
db_mock.fetch.return_value = []
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200
assert create_pool_mock.call_args == mock.call(
host="localhost",
port=5432,
user="user",
password="secret123",
database="myapp",
min_size=0,
max_size=20,
max_inactive_connection_lifetime=0,
)
assert create_pool_mock.return_value.acquire.call_count == 1
assert db_mock.execute.call_args == mock.call("SELECT 1")
assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
assert create_pool_mock.return_value.release.call_count == 1
@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_reestablish_connection(create_pool_mock, app):
db_mock = create_pool_mock.return_value.acquire.return_value
db_mock.execute.side_effect = ConnectionDoesNotExistError
db_mock.fetch.return_value = []
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 200
assert create_pool_mock.called
assert create_pool_mock.return_value.acquire.call_count == 2
assert db_mock.execute.call_args == mock.call("SELECT 1")
assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
assert create_pool_mock.return_value.release.call_count == 1
@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_connection(create_pool_mock, app):
create_pool_mock.return_value.acquire.side_effect = ConnectionRefusedError
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 500
assert response.headers["content-type"] == "application/json"
assert response.json() == {"detail": "Unable to connect to the database."}
assert create_pool_mock.called
assert create_pool_mock.return_value.acquire.call_count == 1
assert not create_pool_mock.return_value.acquire.return_value.fetch.called
assert not create_pool_mock.return_value.release.called
@mock.patch("myapp.database.asyncpg.create_pool", new_callable=AsyncMagicMock)
def test_database_failed_query(create_pool_mock, app):
db_mock = create_pool_mock.return_value.acquire.return_value
db_mock.fetch.side_effect = SyntaxOrAccessError
with TestClient(app) as client:
response = client.get("/")
assert response.status_code == 500
assert response.headers["content-type"] == "application/json"
assert response.json() == {
"detail": "Unable to execute the required query to obtain data from the database."
}
assert create_pool_mock.called
assert create_pool_mock.return_value.acquire.call_count == 1
assert db_mock.fetch.call_args == mock.call("SELECT * FROM data")
assert create_pool_mock.return_value.release.call_count == 1 |
Beta Was this translation helpful? Give feedback.
-
|
I'm happy to close this issue now if there are no further comments. Please let me know your preference. 😄 Kindest regards |
Beta Was this translation helpful? Give feedback.
-
|
As there are no further follow-up comments, I'll close this ticket. Just wanted to thank everyone greatly for their input and help here! 😄 Cheers |
Beta Was this translation helpful? Give feedback.
-
|
Just wanted to say a huge thanks to the FastAPI team for this new feature https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/ which completely solves this dilemma. Absolutely beautiful! 😍 Thanks again |
Beta Was this translation helpful? Give feedback.
-
|
Thanks for the help here everyone! 👏 🙇 Thanks for reporting back and closing the issue @fgimian 👍 I'm glad you're liking FastAPI! 😄 |
Beta Was this translation helpful? Give feedback.
-
Hi, Can you give a final best approach for having global database access with yield? |
Beta Was this translation helpful? Give feedback.
-
|
With Python3.8 this is easily remedied like so: from unittest.mock import patch, AsyncMock
@pytest.fixture(autouse=True)
def mocked_asyncpg():
with patch("circular_api.utils.database.create_pool", new_callable=AsyncMock) as mocked_pool:
mocked = mocked_pool.return_value.acquire.return_value
yield mocked |
Beta Was this translation helpful? Give feedback.
So far, the only idea I can come up with is writing a
create_appfunction which selectively adds everything: