Skip to content

Commit

Permalink
Added iterate
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed Feb 25, 2018
1 parent cd4e7ac commit d38a7b9
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 60 deletions.
5 changes: 2 additions & 3 deletions examples/delegated.py
Expand Up @@ -30,9 +30,8 @@ async def main():
async with conn.transaction():
async for u in conn.iterate(User.query.where(User.id > 3)):
print(u)
await db.create_pool('postgresql://localhost/gino')
async with db.transaction() as (conn, tx):
async for u in conn.iterate(User.query.where(User.id > 3)):
async with db.transaction():
async for u in User.query.where(User.id > 3).gino.iterate():
print(u)


Expand Down
24 changes: 16 additions & 8 deletions gino/api.py
Expand Up @@ -6,7 +6,6 @@

from .crud import CRUDModel
from .declarative import declarative_base
from .dialects.asyncpg import GinoCursorFactory
from . import json_support


Expand Down Expand Up @@ -53,10 +52,13 @@ async def status(self, *multiparams, bind=None, **params):
return await bind.status(self._query, *multiparams, **params)

def iterate(self, *multiparams, connection=None, **params):
def env_factory():
conn = connection or self._query.bind
return conn, conn.metadata
return GinoCursorFactory(env_factory, self._query, multiparams, params)
if connection is None:
if self._query.bind:
connection = self._query.bind.current_connection
if connection is None:
raise ValueError(
'No Connection in context, please provide one')
return connection.iterate(self._query, *multiparams, **params)


class Gino(sa.MetaData):
Expand Down Expand Up @@ -102,9 +104,15 @@ async def scalar(self, clause, *multiparams, **params):
async def status(self, clause, *multiparams, **params):
return await self.bind.status(clause, *multiparams, **params)

def iterate(self, clause, *multiparams, connection=None, **params):
return GinoCursorFactory(lambda: (connection or self.bind, self),
clause, multiparams, params)
def iterate(self, clause, *multiparams, **params):
connection = None
if self.bind:
connection = self.bind.current_connection
if connection is None:
raise ValueError(
'No Connection in context, please provide one')
return self.bind.current_connection.iterate(clause, *multiparams,
**params)

def acquire(self, *args, **kwargs):
return self.bind.acquire(*args, **kwargs)
Expand Down
74 changes: 30 additions & 44 deletions gino/dialects/asyncpg.py
Expand Up @@ -60,98 +60,79 @@ def process_rows(self, rows, return_model=True):
rv.append(obj)
return rv

async def prepare(self, named=True):
return await self.connection.prepare(self.statement, named)

def get_result_proxy(self):
return ResultProxy(self)


class GinoCursorFactory:
def __init__(self, env_factory, clause, multiparams, params):
self._env_factory = env_factory
self._context = None
self._timeout = None
self._clause = clause
self._multiparams = multiparams
self._params = params

@property
def timeout(self):
return self._timeout

async def get_cursor_factory(self):
connection, metadata = self._env_factory()
self._context = metadata.dialect.execution_ctx_cls.init_clause(
metadata.dialect, self._clause, self._multiparams, self._params,
connection)
if self._context.executemany:
raise ValueError('too many multiparams')
self._timeout = self._context.timeout
ps = await self._context.prepare()
return ps.cursor(*self._context.parameters[0], timeout=self._timeout)
class CursorFactory:
def __init__(self, context):
self._context = context

@property
def context(self):
return self._context

async def get_raw_cursor(self):
prepared = await self._context.cursor.prepare(self._context.statement)
return prepared.cursor(*self._context.parameters[0],
timeout=self._context.timeout)

def __aiter__(self):
return GinoCursorIterator(self)
return CursorIterator(self)

def __await__(self):
return GinoCursor(self).async_init().__await__()
return Cursor(self).async_init().__await__()


class GinoCursorIterator:
class CursorIterator:
def __init__(self, factory):
self._factory = factory
self._iterator = None

def __aiter__(self):
return self

async def __anext__(self):
if self._iterator is None:
factory = await self._factory.get_cursor_factory()
self._iterator = factory.__aiter__()
raw = await self._factory.get_raw_cursor()
self._iterator = raw.__aiter__()
row = await self._iterator.__anext__()
return self._factory.context.process_rows([row])[0]


class GinoCursor:
class Cursor:
def __init__(self, factory):
self._factory = factory
self._cursor = None

async def async_init(self):
factory = await self._factory.get_cursor_factory()
self._cursor = await factory
raw = await self._factory.get_raw_cursor()
self._cursor = await raw
return self

async def many(self, n, *, timeout=DEFAULT):
if timeout is DEFAULT:
timeout = self._factory.timeout
timeout = self._factory.context.timeout
rows = await self._cursor.fetch(n, timeout=timeout)
return self._factory.context.process_rows(rows)

async def next(self, *, timeout=DEFAULT):
if timeout is DEFAULT:
timeout = self._factory.timeout
timeout = self._factory.context.timeout
row = await self._cursor.fetchrow(timeout=timeout)
if not row:
return None
return self._factory.context.process_rows([row])[0]

def __getattr__(self, item):
return getattr(self._cursor, item)


class Cursor:
class DBAPICursor:
def __init__(self, apg_conn):
self._conn = apg_conn
self._stmt = None

def execute(self, statement, parameters):
pass

def executemany(self, statement, parameters):
pass

@property
def stmt_exclusive_section(self):
return getattr(self._conn, '_stmt_exclusive_section')
Expand Down Expand Up @@ -215,6 +196,11 @@ async def execute(self, one=False, return_model=True, status=False):
rv.append(item)
return rv

def iterate(self):
if self._context.executemany:
raise ValueError('too many multiparams')
return CursorFactory(self._context)


# noinspection PyAbstractClass
class AsyncpgDialect(PGDialect):
Expand All @@ -223,7 +209,7 @@ class AsyncpgDialect(PGDialect):
default_paramstyle = 'numeric'
statement_compiler = AsyncpgCompiler
execution_ctx_cls = AsyncpgExecutionContext
cursor_cls = Cursor
cursor_cls = DBAPICursor
dbapi_type_map = {
114: JSON(),
3802: JSONB(),
Expand Down
16 changes: 14 additions & 2 deletions gino/engine.py
Expand Up @@ -141,6 +141,13 @@ async def _acquire(self, timeout, reuse):
stack.append(rv)
return functools.partial(self._release, stack), rv

@property
def current_connection(self):
try:
return self._ctx.get()[-1]
except (LookupError, IndexError):
pass

async def _release(self, stack):
await self._dialect.release_conn(stack.pop().raw_connection)

Expand Down Expand Up @@ -280,8 +287,9 @@ async def scalar(self, clause, *multiparams, **params):
result = self._execute(clause, multiparams, params)
rv = await result.execute(one=True, return_model=False)
if rv:
rv = rv[0][0]
return rv
return rv[0][0]
else:
return None

async def status(self, clause, *multiparams, **params):
"""
Expand All @@ -292,3 +300,7 @@ async def status(self, clause, *multiparams, **params):

def transaction(self, *args, **kwargs):
return GinoTransaction(self, args, kwargs)

def iterate(self, clause, *multiparams, **params):
result = self._execute(clause, multiparams, params)
return result.iterate()
8 changes: 5 additions & 3 deletions tests/conftest.py
Expand Up @@ -8,12 +8,14 @@
import gino
from .models import db, DB_ARGS, ASYNCPG_URL

ECHO = False


@pytest.fixture(scope='module')
def sa_engine():
rv = sqlalchemy.create_engine(
'postgresql://{user}:{password}@{host}:{port}/{database}'.format(
**DB_ARGS))
**DB_ARGS), echo=ECHO)
db.create_all(rv)
yield rv
db.drop_all(rv)
Expand All @@ -22,7 +24,7 @@ def sa_engine():

@pytest.fixture
async def engine(sa_engine):
e = await gino.create_engine(ASYNCPG_URL)
e = await gino.create_engine(ASYNCPG_URL, echo=ECHO)
yield e
await e.close()
sa_engine.execute('DELETE FROM gino_users')
Expand All @@ -31,7 +33,7 @@ async def engine(sa_engine):
# noinspection PyUnusedLocal,PyShadowingNames
@pytest.fixture
async def bind(sa_engine):
rv = await db.create_engine(ASYNCPG_URL)
rv = await db.create_engine(ASYNCPG_URL, echo=ECHO)
yield rv
await db.dispose_engine()
sa_engine.execute('DELETE FROM gino_users')
Expand Down
6 changes: 6 additions & 0 deletions tests/test_basic.py
Expand Up @@ -227,3 +227,9 @@ async def test_too_many_engine_args():
import gino
with pytest.raises(TypeError):
await gino.create_engine(ASYNCPG_URL, non_exist=None)


# noinspection PyUnusedLocal
async def test_scalar_return_none(bind):
assert await User.query.where(
User.nickname == 'nonexist').gino.scalar() is None
80 changes: 80 additions & 0 deletions tests/test_iterate.py
@@ -0,0 +1,80 @@
import pytest

from .models import db, User

pytestmark = pytest.mark.asyncio


@pytest.fixture
async def names(sa_engine):
rv = {'11', '22', '33'}
sa_engine.execute(User.__table__.insert(),
[dict(nickname=name) for name in rv])
yield rv
sa_engine.execute('DELETE FROM gino_users')


# noinspection PyUnusedLocal,PyShadowingNames
async def test_bind(bind, names):
with pytest.raises(ValueError, match='No Connection in context'):
async for u in User.query.gino.iterate():
assert False, 'Should not reach here'
with pytest.raises(ValueError, match='No Connection in context'):
await User.query.gino.iterate()
with pytest.raises(ValueError, match='No Connection in context'):
await db.iterate(User.query)

result = set()
async with bind.transaction():
async for u in User.query.gino.iterate():
result.add(u.nickname)
assert names == result

result = set()
async with bind.transaction():
async for u in db.iterate(User.query):
result.add(u.nickname)
assert names == result

result = set()
async with bind.transaction():
cursor = await User.query.gino.iterate()
result.add((await cursor.next()).nickname)
assert names != result
result.update([u.nickname for u in await cursor.many(1)])
assert names != result
result.update([u.nickname for u in await cursor.many(2)])
assert names == result
result.update([u.nickname for u in await cursor.many(2)])
assert names == result
assert await cursor.next() is None

with pytest.raises(ValueError, match='too many multiparams'):
async with bind.transaction():
await db.iterate(User.insert().returning(User.nickname), [
dict(nickname='444'),
dict(nickname='555'),
dict(nickname='666'),
])


# noinspection PyUnusedLocal,PyShadowingNames
async def test_basic(engine, names):
result = set()
async with engine.transaction() as tx:
with pytest.raises(ValueError, match='No Connection in context'):
async for u in User.query.gino.iterate():
assert False, 'Should not reach here'
with pytest.raises(ValueError, match='No Connection in context'):
await db.iterate(User.query)
result = set()
async for u in User.query.gino.iterate(connection=tx.connection):
result.add(u.nickname)
assert names == result

result = set()
cursor = await User.query.gino.iterate(connection=tx.connection)
result.update([u.nickname for u in await cursor.many(2)])
assert names != result
result.update([u.nickname for u in await cursor.many(2)])
assert names == result

0 comments on commit d38a7b9

Please sign in to comment.