Skip to content

Commit

Permalink
Expose query.options to SQLAlchemyManager from the Resource (#168)
Browse files Browse the repository at this point in the history
Useful to perform some prefetching of your relations in order to avoid
the n+1 queries problem.

```python
class TypeResource(ModelResource):
    class Meta:
        model = Type
        include_type = True
        query_options = [joinedload(Type.machines)]

    class Schema:
        machines = fields.ToMany('machine')
```
  • Loading branch information
ticosax authored and lyschoening committed Jan 17, 2019
1 parent 1ae0208 commit f804eb0
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 4 deletions.
7 changes: 6 additions & 1 deletion flask_potion/contrib/alchemy/manager.py
Expand Up @@ -151,7 +151,12 @@ def _is_change(a, b):
return (a is None) != (b is None) or a != b

def _query(self):
return self.model.query
query = self.model.query
try:
query_options = self.resource.meta.query_options
except KeyError:
return query
return query.options(*query_options)

def _query_filter(self, query, expression):
return query.filter(expression)
Expand Down
52 changes: 51 additions & 1 deletion tests/__init__.py
@@ -1,6 +1,9 @@
from pprint import pformat

from flask import json, Flask
from flask.testing import FlaskClient
from flask_testing import TestCase
import sqlalchemy


class ApiClient(FlaskClient):
Expand Down Expand Up @@ -49,4 +52,51 @@ def create_app(self):
return app

def pp(self, obj):
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))
print(json.dumps(obj, sort_keys=True, indent=4, separators=(',', ': ')))


class DBQueryCounter:
"""
Use as a context manager to count the number of execute()'s performed
against the given sqlalchemy connection.
Usage:
with DBQueryCounter(db.session) as ctr:
db.session.execute("SELECT 1")
db.session.execute("SELECT 1")
ctr.assert_count(2)
"""

def __init__(self, session, reset=True):
self.session = session
self.reset = reset
self.statements = []

def __enter__(self):
if self.reset:
self.session.expire_all()
sqlalchemy.event.listen(
self.session.get_bind(), 'after_execute', self._callback
)
return self

def __exit__(self, *_):
sqlalchemy.event.remove(
self.session.get_bind(), 'after_execute', self._callback
)

def get_count(self):
return len(self.statements)

def _callback(self, conn, clause_element, multiparams, params, result):
self.statements.append((clause_element, multiparams, params))

def display_all(self):
for clause, multiparams, params in self.statements:
print(pformat(str(clause)), multiparams, params)
print('\n')
count = self.get_count()
return 'Counted: {count}'.format(count=count)

def assert_count(self, expected):
count = self.get_count()
assert count == expected, self.display_all()
78 changes: 76 additions & 2 deletions tests/contrib/alchemy/test_manager_sqlalchemy.py
@@ -1,11 +1,11 @@
import unittest
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import backref
from sqlalchemy.orm import backref, joinedload
from flask_potion.routes import Relation
from flask_potion.contrib.alchemy import SQLAlchemyManager
from flask_potion import Api, fields
from flask_potion.resource import ModelResource
from tests import BaseTestCase
from tests import BaseTestCase, DBQueryCounter


class SQLAlchemyTestCase(BaseTestCase):
Expand Down Expand Up @@ -523,3 +523,77 @@ def test_sort_by_related_field(self):
self.assert200(response)
type_uris = [entry['type']['$ref'] for entry in response.json]
self.assertTrue(type_uris, [bbb_uri, aaa_uri])


class QueryOptionsSQLAlchemyTestCase(BaseTestCase):
def setUp(self):
super(QueryOptionsSQLAlchemyTestCase, self).setUp()
self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
self.api = Api(self.app, default_manager=SQLAlchemyManager)
self.sa = sa = SQLAlchemy(self.app, session_options={"autoflush": False})

class Type(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False, unique=True)
version = sa.Column(sa.Integer(), nullable=True)
machines = sa.relationship('Machine', back_populates='type')

class Machine(sa.Model):
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(60), nullable=False)

wattage = sa.Column(sa.Float)

type_id = sa.Column(sa.Integer, sa.ForeignKey(Type.id))
type = sa.relationship(Type, back_populates='machines')

sa.create_all()

class MachineResource(ModelResource):
class Meta:
model = Machine
include_type = True

class Schema:
type = fields.ToOne('type')

class TypeResource(ModelResource):
class Meta:
model = Type
include_type = True
query_options = [joinedload(Type.machines)]

class Schema:
machines = fields.ToMany('machine')

self.MachineResource = MachineResource
self.TypeResource = TypeResource

self.api.add_resource(MachineResource)
self.api.add_resource(TypeResource)

def tearDown(self):
self.sa.drop_all()

def test_get(self):
response = self.client.post('/type', data={"name": "aaa"})
self.assert200(response)
aaa_uri = response.json["$uri"]

response = self.client.post(
'/machine', data={"name": "foo", "type": {"$ref": aaa_uri}})
self.assert200(response)
machine_uri = response.json['$uri']

with DBQueryCounter(self.sa.session) as counter:
response = self.client.get(aaa_uri)
self.assert200(response)
self.assertJSONEqual(
response.json,
{'$type': 'type',
'$uri': aaa_uri,
'machines': [{'$ref': machine_uri}],
'name': 'aaa',
'version': None,
})
counter.assert_count(1)

0 comments on commit f804eb0

Please sign in to comment.