Skip to content

Commit

Permalink
Merge aa01c83 into b4ad66a
Browse files Browse the repository at this point in the history
  • Loading branch information
wwwjfy committed Jun 3, 2019
2 parents b4ad66a + aa01c83 commit 6b8ec97
Show file tree
Hide file tree
Showing 6 changed files with 344 additions and 3 deletions.
6 changes: 3 additions & 3 deletions gino/ext/aiohttp.py
Expand Up @@ -122,8 +122,8 @@ def init_app(self, app, config=None, *, db_attr_name='db'):
else:
config = config.copy()

async def before_server_start(app_):
if config.get('dsn'):
async def before_server_start(_):
if 'dsn' in config:
dsn = config['dsn']
else:
dsn = URL(
Expand All @@ -144,7 +144,7 @@ async def before_server_start(app_):
**config.setdefault('kwargs', dict()),
)

async def after_server_stop(app_):
async def after_server_stop(_):
await self.pop_bind().close()

app.on_startup.append(before_server_start)
Expand Down
2 changes: 2 additions & 0 deletions gino/ext/quart.py
@@ -1,6 +1,8 @@
import asyncio

# noinspection PyPackageRequirements
from quart import Quart, request
# noinspection PyPackageRequirements
from quart.exceptions import NotFound
from sqlalchemy.engine.url import URL

Expand Down
180 changes: 180 additions & 0 deletions gino/ext/starlette.py
@@ -0,0 +1,180 @@
# noinspection PyPackageRequirements
from starlette.applications import Starlette
# noinspection PyPackageRequirements
from starlette.types import Message, Receive, Scope, Send
# noinspection PyPackageRequirements
from starlette.exceptions import HTTPException
# noinspection PyPackageRequirements
from starlette import status
from sqlalchemy.engine.url import URL

from ..api import Gino as _Gino, GinoExecutor as _Executor
from ..engine import GinoConnection as _Connection, GinoEngine as _Engine
from ..strategies import GinoStrategy


class StarletteModelMixin:
@classmethod
async def get_or_404(cls, *args, **kwargs):
# noinspection PyUnresolvedReferences
rv = await cls.get(*args, **kwargs)
if rv is None:
raise HTTPException(status.HTTP_404_NOT_FOUND,
'{} is not found'.format(cls.__name__))
return rv


# noinspection PyClassHasNoInit
class GinoExecutor(_Executor):
async def first_or_404(self, *args, **kwargs):
rv = await self.first(*args, **kwargs)
if rv is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
return rv


# noinspection PyClassHasNoInit
class GinoConnection(_Connection):
async def first_or_404(self, *args, **kwargs):
rv = await self.first(*args, **kwargs)
if rv is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
return rv


# noinspection PyClassHasNoInit
class GinoEngine(_Engine):
connection_cls = GinoConnection

async def first_or_404(self, *args, **kwargs):
rv = await self.first(*args, **kwargs)
if rv is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
return rv


class StarletteStrategy(GinoStrategy):
name = 'starlette'
engine_cls = GinoEngine


StarletteStrategy()


class _Middleware:
def __init__(self, app, db):
self.app = app
self.db = db

async def __call__(self, scope: Scope, receive: Receive,
send: Send) -> None:
if (scope['type'] == 'http' and
self.db.config['use_connection_for_request']):
scope['connection'] = await self.db.acquire(lazy=True)
await self.app(scope, receive, send)
conn = scope.pop('connection', None)
if conn is not None:
await conn.release()
return

if scope['type'] == 'lifespan':
async def receiver() -> Message:
message = await receive()
if message["type"] == "lifespan.startup":
await self.db.set_bind(
self.db.config['dsn'],
echo=self.db.config['echo'],
min_size=self.db.config['min_size'],
max_size=self.db.config['max_size'],
ssl=self.db.config['ssl'],
**self.db.config['kwargs'],
)
elif message["type"] == "lifespan.shutdown":
await self.db.pop_bind().close()
return message
await self.app(scope, receiver, send)
return

await self.app(scope, receive, send)


class Gino(_Gino):
"""Support Starlette server.
The common usage looks like this::
from starlette.applications import Starlette
from gino.ext.starlette import Gino
app = Starlette()
db = Gino(app, **kwargs)
GINO adds a middleware to the Starlette app to setup and cleanup database
according to the configurations that passed in the ``kwargs`` parameter.
The config includes:
* ``driver`` - the database driver, default is ``asyncpg``.
* ``host`` - database server host, default is ``localhost``.
* ``port`` - database server port, default is ``5432``.
* ``user`` - database server user, default is ``postgres``.
* ``password`` - database server password, default is empty.
* ``database`` - database name, default is ``postgres``.
* ``dsn`` - a SQLAlchemy database URL to create the engine, its existence
will replace all previous connect arguments.
* ``pool_min_size`` - the initial number of connections of the db pool.
* ``pool_max_size`` - the maximum number of connections in the db pool.
* ``echo`` - enable SQLAlchemy echo mode.
* ``ssl`` - SSL context passed to ``asyncpg.connect``, default is ``None``.
* ``use_connection_for_request`` - flag to set up lazy connection for
requests.
* ``kwargs`` - other parameters passed to the specified dialects,
like ``asyncpg``. Unrecognized parameters will cause exceptions.
If ``use_connection_for_request`` is set to be True, then a lazy connection
is available at ``request['connection']``. By default, a database
connection is borrowed on the first query, shared in the same execution
context, and returned to the pool on response. If you need to release the
connection early in the middle to do some long-running tasks, you can
simply do this::
await request['connection'].release(permanent=False)
"""
model_base_classes = _Gino.model_base_classes + (StarletteModelMixin,)
query_executor = GinoExecutor

def __init__(self, app: Starlette, *args, **kwargs):
self.config = dict()
if 'dsn' in kwargs:
self.config['dsn'] = kwargs.pop('dsn')
else:
self.config['dsn'] = URL(
drivername=kwargs.pop('driver', 'asyncpg'),
host=kwargs.pop('host', 'localhost'),
port=kwargs.pop('port', 5432),
username=kwargs.pop('user', 'postgres'),
password=kwargs.pop('password', ''),
database=kwargs.pop('database', 'postgres'),
)
self.config['echo'] = kwargs.pop('echo', False)
self.config['min_size'] = kwargs.pop('pool_min_size', 5)
self.config['max_size'] = kwargs.pop('pool_max_size', 10)
self.config['ssl'] = kwargs.pop('ssl', None)
self.config['use_connection_for_request'] = \
kwargs.pop('use_connection_for_request', True)
self.config['kwargs'] = kwargs.pop('kwargs', dict())

super().__init__(*args, **kwargs)

app.add_middleware(_Middleware, db=self)

async def first_or_404(self, *args, **kwargs):
rv = await self.first(*args, **kwargs)
if rv is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, 'No such data')
return rv

async def set_bind(self, bind, loop=None, **kwargs):
kwargs.setdefault('strategy', 'starlette')
return await super().set_bind(bind, loop=loop, **kwargs)
1 change: 1 addition & 0 deletions requirements_dev.txt
Expand Up @@ -10,6 +10,7 @@ aiohttp==3.5.0 # pyup: update minor
tornado==6.0 # pyup: update minor
async_generator==1.10 # pyup: update minor
quart==0.9.1;python_version>="3.7" # pyup: update minor
starlette==0.12.0;python_version>="3.6" # pyup: update minor

# tests
coverage==4.5.1 # pyup: update minor
Expand Down
7 changes: 7 additions & 0 deletions tests/test_sanic.py
@@ -1,3 +1,5 @@
import asyncio

from async_generator import yield_, async_generator
import pytest
import sanic
Expand All @@ -11,6 +13,11 @@
_MAX_INACTIVE_CONNECTION_LIFETIME = 59.0


def teardown_module():
# sanic server will close the loop during shutdown
asyncio.set_event_loop(asyncio.new_event_loop())


# noinspection PyShadowingNames
async def _app(config):
app = sanic.Sanic()
Expand Down
151 changes: 151 additions & 0 deletions tests/test_starlette.py
@@ -0,0 +1,151 @@
from async_generator import yield_, async_generator
import pytest
from starlette.applications import Starlette
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.testclient import TestClient

import gino
from gino.ext.starlette import Gino

from .models import DB_ARGS, PG_URL

_MAX_INACTIVE_CONNECTION_LIFETIME = 59.0


# noinspection PyShadowingNames
async def _app(**kwargs):
app = Starlette()
kwargs.update({
'kwargs': dict(
max_inactive_connection_lifetime=_MAX_INACTIVE_CONNECTION_LIFETIME,
),
})
db = Gino(app, **kwargs)

class User(db.Model):
__tablename__ = 'gino_users'

id = db.Column(db.BigInteger(), primary_key=True)
nickname = db.Column(db.Unicode(), default='noname')

@app.route('/')
async def root(request):
conn = await request['connection'].get_raw_connection()
# noinspection PyProtectedMember
assert conn._holder._max_inactive_time == \
_MAX_INACTIVE_CONNECTION_LIFETIME
return PlainTextResponse('Hello, world!')

@app.route('/users/{uid:int}')
async def get_user(request):
uid = request.path_params.get('uid')
method = request.query_params.get('method')
q = User.query.where(User.id == uid)
if method == '1':
return JSONResponse((await q.gino.first_or_404()).to_dict())
elif method == '2':
return JSONResponse(
(await request['connection'].first_or_404(q)).to_dict())
elif method == '3':
return JSONResponse(
(await db.bind.first_or_404(q)).to_dict())
elif method == '4':
return JSONResponse(
(await db.first_or_404(q)).to_dict())
else:
return JSONResponse((await User.get_or_404(uid)).to_dict())

@app.route('/users', methods=['POST'])
async def add_user(request):
u = await User.create(nickname=(await request.json()).get('name'))
await u.query.gino.first_or_404()
await db.first_or_404(u.query)
await db.bind.first_or_404(u.query)
await request['connection'].first_or_404(u.query)
return JSONResponse(u.to_dict())

e = await gino.create_engine(PG_URL)
try:
try:
await db.gino.create_all(e)
await yield_(app)
finally:
await db.gino.drop_all(e)
finally:
await e.close()


@pytest.fixture
@async_generator
async def app():
await _app(
host=DB_ARGS['host'],
port=DB_ARGS['port'],
user=DB_ARGS['user'],
password=DB_ARGS['password'],
database=DB_ARGS['database'],
)


@pytest.fixture
@async_generator
async def app_ssl(ssl_ctx):
await _app(
host=DB_ARGS['host'],
port=DB_ARGS['port'],
user=DB_ARGS['user'],
password=DB_ARGS['password'],
database=DB_ARGS['database'],
ssl=ssl_ctx,
)


@pytest.fixture
@async_generator
async def app_dsn():
await _app(dsn=PG_URL)


def _test_index_returns_200(app):
client = TestClient(app)
with client:
response = client.get('/')
assert response.status_code == 200
assert response.text == 'Hello, world!'


def test_index_returns_200(app):
_test_index_returns_200(app)


def test_index_returns_200_dsn(app_dsn):
_test_index_returns_200(app_dsn)


def _test(app):
client = TestClient(app)
with client:
for method in '01234':
response = client.get('/users/1?method=' + method)
assert response.status_code == 404

response = client.post('/users', json=dict(name='fantix'))
assert response.status_code == 200
assert response.json() == dict(id=1, nickname='fantix')

for method in '01234':
response = client.get('/users/1?method=' + method)
assert response.status_code == 200
assert response.json() == dict(id=1, nickname='fantix')


def test(app):
_test(app)


def test_ssl(app_ssl):
_test(app_ssl)


def test_dsn(app_dsn):
_test(app_dsn)

0 comments on commit 6b8ec97

Please sign in to comment.