Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
344 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
# noinspection PyPackageRequirements | ||
from starlette.applications import Starlette | ||
# noinspection PyPackageRequirements | ||
from starlette.types import Message, Receive, Scope, Send | ||
# noinspection PyPackageRequirements | ||
from starlette.exceptions import HTTPException | ||
# noinspection PyPackageRequirements | ||
from starlette import status | ||
from sqlalchemy.engine.url import URL | ||
|
||
from ..api import Gino as _Gino, GinoExecutor as _Executor | ||
from ..engine import GinoConnection as _Connection, GinoEngine as _Engine | ||
from ..strategies import GinoStrategy | ||
|
||
|
||
class StarletteModelMixin: | ||
@classmethod | ||
async def get_or_404(cls, *args, **kwargs): | ||
# noinspection PyUnresolvedReferences | ||
rv = await cls.get(*args, **kwargs) | ||
if rv is None: | ||
raise HTTPException(status.HTTP_404_NOT_FOUND, | ||
'{} is not found'.format(cls.__name__)) | ||
return rv | ||
|
||
|
||
# noinspection PyClassHasNoInit | ||
class GinoExecutor(_Executor): | ||
async def first_or_404(self, *args, **kwargs): | ||
rv = await self.first(*args, **kwargs) | ||
if rv is None: | ||
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data') | ||
return rv | ||
|
||
|
||
# noinspection PyClassHasNoInit | ||
class GinoConnection(_Connection): | ||
async def first_or_404(self, *args, **kwargs): | ||
rv = await self.first(*args, **kwargs) | ||
if rv is None: | ||
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data') | ||
return rv | ||
|
||
|
||
# noinspection PyClassHasNoInit | ||
class GinoEngine(_Engine): | ||
connection_cls = GinoConnection | ||
|
||
async def first_or_404(self, *args, **kwargs): | ||
rv = await self.first(*args, **kwargs) | ||
if rv is None: | ||
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data') | ||
return rv | ||
|
||
|
||
class StarletteStrategy(GinoStrategy): | ||
name = 'starlette' | ||
engine_cls = GinoEngine | ||
|
||
|
||
StarletteStrategy() | ||
|
||
|
||
class _Middleware: | ||
def __init__(self, app, db): | ||
self.app = app | ||
self.db = db | ||
|
||
async def __call__(self, scope: Scope, receive: Receive, | ||
send: Send) -> None: | ||
if (scope['type'] == 'http' and | ||
self.db.config['use_connection_for_request']): | ||
scope['connection'] = await self.db.acquire(lazy=True) | ||
await self.app(scope, receive, send) | ||
conn = scope.pop('connection', None) | ||
if conn is not None: | ||
await conn.release() | ||
return | ||
|
||
if scope['type'] == 'lifespan': | ||
async def receiver() -> Message: | ||
message = await receive() | ||
if message["type"] == "lifespan.startup": | ||
await self.db.set_bind( | ||
self.db.config['dsn'], | ||
echo=self.db.config['echo'], | ||
min_size=self.db.config['min_size'], | ||
max_size=self.db.config['max_size'], | ||
ssl=self.db.config['ssl'], | ||
**self.db.config['kwargs'], | ||
) | ||
elif message["type"] == "lifespan.shutdown": | ||
await self.db.pop_bind().close() | ||
return message | ||
await self.app(scope, receiver, send) | ||
return | ||
|
||
await self.app(scope, receive, send) | ||
|
||
|
||
class Gino(_Gino): | ||
"""Support Starlette server. | ||
The common usage looks like this:: | ||
from starlette.applications import Starlette | ||
from gino.ext.starlette import Gino | ||
app = Starlette() | ||
db = Gino(app, **kwargs) | ||
GINO adds a middleware to the Starlette app to setup and cleanup database | ||
according to the configurations that passed in the ``kwargs`` parameter. | ||
The config includes: | ||
* ``driver`` - the database driver, default is ``asyncpg``. | ||
* ``host`` - database server host, default is ``localhost``. | ||
* ``port`` - database server port, default is ``5432``. | ||
* ``user`` - database server user, default is ``postgres``. | ||
* ``password`` - database server password, default is empty. | ||
* ``database`` - database name, default is ``postgres``. | ||
* ``dsn`` - a SQLAlchemy database URL to create the engine, its existence | ||
will replace all previous connect arguments. | ||
* ``pool_min_size`` - the initial number of connections of the db pool. | ||
* ``pool_max_size`` - the maximum number of connections in the db pool. | ||
* ``echo`` - enable SQLAlchemy echo mode. | ||
* ``ssl`` - SSL context passed to ``asyncpg.connect``, default is ``None``. | ||
* ``use_connection_for_request`` - flag to set up lazy connection for | ||
requests. | ||
* ``kwargs`` - other parameters passed to the specified dialects, | ||
like ``asyncpg``. Unrecognized parameters will cause exceptions. | ||
If ``use_connection_for_request`` is set to be True, then a lazy connection | ||
is available at ``request['connection']``. By default, a database | ||
connection is borrowed on the first query, shared in the same execution | ||
context, and returned to the pool on response. If you need to release the | ||
connection early in the middle to do some long-running tasks, you can | ||
simply do this:: | ||
await request['connection'].release(permanent=False) | ||
""" | ||
model_base_classes = _Gino.model_base_classes + (StarletteModelMixin,) | ||
query_executor = GinoExecutor | ||
|
||
def __init__(self, app: Starlette, *args, **kwargs): | ||
self.config = dict() | ||
if 'dsn' in kwargs: | ||
self.config['dsn'] = kwargs.pop('dsn') | ||
else: | ||
self.config['dsn'] = URL( | ||
drivername=kwargs.pop('driver', 'asyncpg'), | ||
host=kwargs.pop('host', 'localhost'), | ||
port=kwargs.pop('port', 5432), | ||
username=kwargs.pop('user', 'postgres'), | ||
password=kwargs.pop('password', ''), | ||
database=kwargs.pop('database', 'postgres'), | ||
) | ||
self.config['echo'] = kwargs.pop('echo', False) | ||
self.config['min_size'] = kwargs.pop('pool_min_size', 5) | ||
self.config['max_size'] = kwargs.pop('pool_max_size', 10) | ||
self.config['ssl'] = kwargs.pop('ssl', None) | ||
self.config['use_connection_for_request'] = \ | ||
kwargs.pop('use_connection_for_request', True) | ||
self.config['kwargs'] = kwargs.pop('kwargs', dict()) | ||
|
||
super().__init__(*args, **kwargs) | ||
|
||
app.add_middleware(_Middleware, db=self) | ||
|
||
async def first_or_404(self, *args, **kwargs): | ||
rv = await self.first(*args, **kwargs) | ||
if rv is None: | ||
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data') | ||
return rv | ||
|
||
async def set_bind(self, bind, loop=None, **kwargs): | ||
kwargs.setdefault('strategy', 'starlette') | ||
return await super().set_bind(bind, loop=loop, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from async_generator import yield_, async_generator | ||
import pytest | ||
from starlette.applications import Starlette | ||
from starlette.responses import JSONResponse, PlainTextResponse | ||
from starlette.testclient import TestClient | ||
|
||
import gino | ||
from gino.ext.starlette import Gino | ||
|
||
from .models import DB_ARGS, PG_URL | ||
|
||
_MAX_INACTIVE_CONNECTION_LIFETIME = 59.0 | ||
|
||
|
||
# noinspection PyShadowingNames | ||
async def _app(**kwargs): | ||
app = Starlette() | ||
kwargs.update({ | ||
'kwargs': dict( | ||
max_inactive_connection_lifetime=_MAX_INACTIVE_CONNECTION_LIFETIME, | ||
), | ||
}) | ||
db = Gino(app, **kwargs) | ||
|
||
class User(db.Model): | ||
__tablename__ = 'gino_users' | ||
|
||
id = db.Column(db.BigInteger(), primary_key=True) | ||
nickname = db.Column(db.Unicode(), default='noname') | ||
|
||
@app.route('/') | ||
async def root(request): | ||
conn = await request['connection'].get_raw_connection() | ||
# noinspection PyProtectedMember | ||
assert conn._holder._max_inactive_time == \ | ||
_MAX_INACTIVE_CONNECTION_LIFETIME | ||
return PlainTextResponse('Hello, world!') | ||
|
||
@app.route('/users/{uid:int}') | ||
async def get_user(request): | ||
uid = request.path_params.get('uid') | ||
method = request.query_params.get('method') | ||
q = User.query.where(User.id == uid) | ||
if method == '1': | ||
return JSONResponse((await q.gino.first_or_404()).to_dict()) | ||
elif method == '2': | ||
return JSONResponse( | ||
(await request['connection'].first_or_404(q)).to_dict()) | ||
elif method == '3': | ||
return JSONResponse( | ||
(await db.bind.first_or_404(q)).to_dict()) | ||
elif method == '4': | ||
return JSONResponse( | ||
(await db.first_or_404(q)).to_dict()) | ||
else: | ||
return JSONResponse((await User.get_or_404(uid)).to_dict()) | ||
|
||
@app.route('/users', methods=['POST']) | ||
async def add_user(request): | ||
u = await User.create(nickname=(await request.json()).get('name')) | ||
await u.query.gino.first_or_404() | ||
await db.first_or_404(u.query) | ||
await db.bind.first_or_404(u.query) | ||
await request['connection'].first_or_404(u.query) | ||
return JSONResponse(u.to_dict()) | ||
|
||
e = await gino.create_engine(PG_URL) | ||
try: | ||
try: | ||
await db.gino.create_all(e) | ||
await yield_(app) | ||
finally: | ||
await db.gino.drop_all(e) | ||
finally: | ||
await e.close() | ||
|
||
|
||
@pytest.fixture | ||
@async_generator | ||
async def app(): | ||
await _app( | ||
host=DB_ARGS['host'], | ||
port=DB_ARGS['port'], | ||
user=DB_ARGS['user'], | ||
password=DB_ARGS['password'], | ||
database=DB_ARGS['database'], | ||
) | ||
|
||
|
||
@pytest.fixture | ||
@async_generator | ||
async def app_ssl(ssl_ctx): | ||
await _app( | ||
host=DB_ARGS['host'], | ||
port=DB_ARGS['port'], | ||
user=DB_ARGS['user'], | ||
password=DB_ARGS['password'], | ||
database=DB_ARGS['database'], | ||
ssl=ssl_ctx, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
@async_generator | ||
async def app_dsn(): | ||
await _app(dsn=PG_URL) | ||
|
||
|
||
def _test_index_returns_200(app): | ||
client = TestClient(app) | ||
with client: | ||
response = client.get('/') | ||
assert response.status_code == 200 | ||
assert response.text == 'Hello, world!' | ||
|
||
|
||
def test_index_returns_200(app): | ||
_test_index_returns_200(app) | ||
|
||
|
||
def test_index_returns_200_dsn(app_dsn): | ||
_test_index_returns_200(app_dsn) | ||
|
||
|
||
def _test(app): | ||
client = TestClient(app) | ||
with client: | ||
for method in '01234': | ||
response = client.get('/users/1?method=' + method) | ||
assert response.status_code == 404 | ||
|
||
response = client.post('/users', json=dict(name='fantix')) | ||
assert response.status_code == 200 | ||
assert response.json() == dict(id=1, nickname='fantix') | ||
|
||
for method in '01234': | ||
response = client.get('/users/1?method=' + method) | ||
assert response.status_code == 200 | ||
assert response.json() == dict(id=1, nickname='fantix') | ||
|
||
|
||
def test(app): | ||
_test(app) | ||
|
||
|
||
def test_ssl(app_ssl): | ||
_test(app_ssl) | ||
|
||
|
||
def test_dsn(app_dsn): | ||
_test(app_dsn) |