Skip to content

Commit

Permalink
Fix select_related empty result handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kvesteri committed Mar 16, 2017
1 parent 7ed0d9c commit c9b988e
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 22 deletions.
48 changes: 31 additions & 17 deletions sqlalchemy_json_api/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,22 @@ def _select_related(self, obj, relationship_key, **kwargs):
from_obj = kwargs.pop('from_obj', None)
if from_obj is None:
from_obj = sa.orm.query.Query(model)
from_obj = from_obj.filter(prop._with_parent(obj))

# SQLAlchemy Query.with_parent throws warning if the primary object
# foreign key is NULL. Thus we need this ugly magic to return empty
# data in that scenario.
if (
prop.direction.name == 'MANYTOONE' and
not prop.secondary and
getattr(obj, prop.local_remote_pairs[0][0].key) is None
):
expr = sa.cast({'data': None}, JSONB)
if kwargs.get('as_text'):
expr = sa.cast(expr, sa.Text)
return sa.select([expr])

from_obj = from_obj.with_parent(obj, prop)

if prop.order_by:
from_obj = from_obj.order_by(*prop.order_by)

Expand Down Expand Up @@ -267,7 +282,18 @@ def select(self, model, **kwargs):
if from_obj is None:
from_obj = sa.orm.query.Query(model)

from_obj = from_obj.subquery()
if kwargs.get('sort') is not None:
from_obj = apply_sort(
from_obj.statement,
from_obj,
kwargs.get('sort')
)
if kwargs.get('limit') is not None:
from_obj = from_obj.limit(kwargs.get('limit'))
if kwargs.get('offset') is not None:
from_obj = from_obj.offset(kwargs.get('offset'))

from_obj = from_obj.cte('main_query')

return SelectExpression(self, model, from_obj).build_select(**kwargs)

Expand Down Expand Up @@ -404,17 +430,11 @@ def _get_from_args(
from_args = [data_query.as_scalar().label('data')]

if params.include:
selectable = get_selectable(self.from_obj).original
if params.sort is not None:
selectable = apply_sort(selectable, selectable, params.sort)
if params.limit is not None:
selectable = selectable.limit(params.limit)
if params.offset is not None:
selectable = selectable.offset(params.offset)
selectable = self.from_obj
include_expr = IncludeExpression(
self.query_builder,
self.model,
selectable.alias()
selectable
)
included_query = include_expr.build_included(params)
from_args.append(included_query.as_scalar().label('included'))
Expand Down Expand Up @@ -717,12 +737,6 @@ def build_data_expr(self, params, ids_only=False):
def build_data(self, params, ids_only=False):
expr = self.build_data_expr(params, ids_only=ids_only)
query = sa.select([expr], from_obj=self.from_obj)
if params.sort is not None:
query = apply_sort(self.from_obj, query, params.sort)
if params.limit is not None:
query = query.limit(params.limit)
if params.offset is not None:
query = query.offset(params.offset)
return query

def build_data_array(self, params, ids_only=False):
Expand Down Expand Up @@ -797,7 +811,7 @@ def build_single_included(self, fields, path):
expr,
path,
alias,
get_selectable(self.from_obj),
self.from_obj,
correlate=False
).distinct()
if cls is self.model:
Expand Down
6 changes: 5 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
Expand All @@ -7,6 +9,8 @@

from sqlalchemy_json_api import QueryBuilder

warnings.filterwarnings('error')


@pytest.fixture(scope='class')
def base():
Expand Down Expand Up @@ -231,7 +235,7 @@ class CompositePKModel(base):

@pytest.fixture(scope='class')
def dns():
return 'postgres://postgres@localhost/sqlalchemy_json_api_test'
return 'postgresql://postgres@localhost/sqlalchemy_json_api_test'


@pytest.yield_fixture(scope='class')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_hybrid_property_inclusion_uses_clause_adaptation(
from_obj=session.query(article_cls)
)
compiled = query.compile(dialect=sa.dialects.postgresql.dialect())
assert 'upper(anon_2.name)' in str(compiled)
assert 'upper(main_query.name)' in str(compiled)

@pytest.mark.parametrize(
('fields', 'include', 'result'),
Expand Down
29 changes: 28 additions & 1 deletion tests/test_select_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,33 @@ def test_empty_result(
query = query_builder.select_related(
session.query(category_cls).get(id),
'parent',
fields={'categories': []}
fields={'categories': []},
session=session
)
assert session.execute(query).scalar() == result

@pytest.mark.parametrize(
('id', 'result'),
(
(
1,
{'data': None}
),
)
)
def test_empty_result_as_text(
self,
query_builder,
session,
category_cls,
id,
result
):
query = query_builder.select_related(
session.query(category_cls).get(id),
'parent',
fields={'categories': []},
session=session,
as_text=True
)
assert json.loads(session.execute(query).scalar()) == result
33 changes: 31 additions & 2 deletions tests/test_select_with_include.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,8 +668,37 @@ def test_hybrid_property_in_included_object(
category_cls.id == 1
)
)
from pprint import pprint
pprint(session.execute(query).scalar())
assert session.execute(query).scalar() == {
'included': [
{
'attributes': {'comment_count': 4},
'type': 'articles',
'id': '1'
}
],
'data': [
{
'relationships': {
'articles': {
'data': [{'type': 'articles', 'id': '1'}]
},
'subcategories': {
'data': [
{'type': 'categories', 'id': '2'},
{'type': 'categories', 'id': '4'}
]
},
'parent': {'data': None}
},
'attributes': {
'created_at': None,
'name': 'Some category'
},
'type': 'categories',
'id': '1'
}
]
}

@pytest.mark.parametrize(
('fields', 'include', 'result'),
Expand Down

0 comments on commit c9b988e

Please sign in to comment.