Skip to content

Commit

Permalink
Merge 28249a8 into 631513f
Browse files Browse the repository at this point in the history
  • Loading branch information
jnak committed Jan 24, 2020
2 parents 631513f + 28249a8 commit 4baae3e
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 120 deletions.
64 changes: 44 additions & 20 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,28 @@
from singledispatch import singledispatch
from sqlalchemy import types
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import interfaces
from sqlalchemy.orm import interfaces, strategies

from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
String)
from graphene.types.json import JSONString

from .batching import get_batch_resolver
from .enums import enum_for_sa_enum
from .fields import (BatchSQLAlchemyConnectionField,
default_connection_field_factory)
from .registry import get_global_registry
from .resolvers import get_attr_resolver, get_custom_resolver

try:
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
except ImportError:
ChoiceType = JSONType = ScalarListType = TSVectorType = object


is_selectin_available = getattr(strategies, 'SelectInLoader', None)


def get_column_doc(column):
return getattr(column, "doc", None)

Expand All @@ -26,29 +33,46 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
direction = relationship_prop.direction
model = relationship_prop.mapper.entity

def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
attr_name, orm_field_name, **field_kwargs):
"""
:param sqlalchemy.RelationshipProperty relationship_prop:
:param Registry registry:
:type function|None connection_field_factory:
:type bool batching:
:param SQLAlchemyObjectType obj_type:
:param str orm_field_name:
:rtype: Dynamic
"""
def dynamic_type():
_type = registry.get_type_for_model(model)
direction = relationship_prop.direction
model = relationship_prop.mapper.entity
type_ = obj_type._meta.registry.get_type_for_model(model)

if not _type:
batching_ = batching if is_selectin_available else False
connection_field_factory_ = connection_field_factory

if not type_:
return None

if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
return Field(
_type,
resolver=resolver,
**field_kwargs
)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
# TODO Add a way to override connection_field_factory
return connection_field_factory(relationship_prop, registry, **field_kwargs)
return Field(
List(_type),
**field_kwargs
)
resolver = get_custom_resolver(obj_type, orm_field_name)
if resolver is None:
resolver = get_batch_resolver(relationship_prop) if batching_ else \
get_attr_resolver(obj_type, relationship_prop.key)

return Field(type_, resolver=resolver, **field_kwargs)

if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if not type_._meta.connection:
return Field(List(type_), **field_kwargs)

if connection_field_factory_ is None:
connection_field_factory_ = BatchSQLAlchemyConnectionField.from_relationship if batching_ else \
default_connection_field_factory

# TODO Allow override of connection_field_factory and resolver via ORMField
return connection_field_factory_(relationship_prop, obj_type._meta.registry, **field_kwargs)

return Dynamic(dynamic_type)

Expand Down
Empty file removed graphene_sqlalchemy/resolver.py
Empty file.
26 changes: 26 additions & 0 deletions graphene_sqlalchemy/resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from graphene.utils.get_unbound_function import get_unbound_function


def get_custom_resolver(obj_type, orm_field_name):
"""
Since `graphene` will call `resolve_<field_name>` on a field only if it
does not have a `resolver`, we need to re-implement that logic here so
users are able to override the default resolvers that we provide.
"""
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
if resolver:
return get_unbound_function(resolver)

return None


def get_attr_resolver(obj_type, model_attr):
"""
In order to support field renaming via `ORMField.model_attr`,
we need to define resolver functions for each field.
:param SQLAlchemyObjectType obj_type:
:param str model_attr: the name of the SQLAlchemy attribute
:rtype: Callable
"""
return lambda root, _info: getattr(root, model_attr, None)
189 changes: 184 additions & 5 deletions graphene_sqlalchemy/tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import graphene
from graphene import relay

from ..fields import BatchSQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from ..fields import (BatchSQLAlchemyConnectionField,
default_connection_field_factory)
from ..types import ORMField, SQLAlchemyObjectType
from .models import Article, HairKind, Pet, Reporter
from .utils import is_sqlalchemy_version_less_than, to_std_dicts

Expand Down Expand Up @@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class PetType(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (relay.Node,)
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
batching = True

class Query(graphene.ObjectType):
articles = graphene.Field(graphene.List(ArticleType))
Expand Down Expand Up @@ -513,3 +514,181 @@ def test_many_to_many(session_factory):
},
],
}


def test_disable_batching_via_ormfield(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = True

favorite_article = ORMField(batching=False)
articles = ORMField(batching=False)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

# Test one-to-one and many-to-one relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
favoriteArticle {
headline
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2

# Test one-to-many and many-to-many relationships
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2


def test_connection_factory_field_overrides_batching_is_false(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = False
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship

articles = ORMField(batching=False)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 1


def test_connection_factory_field_overrides_batching_is_true(session_factory):
session = session_factory()
reporter_1 = Reporter(first_name='Reporter_1')
session.add(reporter_1)
reporter_2 = Reporter(first_name='Reporter_2')
session.add(reporter_2)
session.commit()
session.close()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (relay.Node,)
batching = True
connection_field_factory = default_connection_field_factory

articles = ORMField(batching=True)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (relay.Node,)

class Query(graphene.ObjectType):
reporters = graphene.Field(graphene.List(ReporterType))

def resolve_reporters(self, info):
return info.context.get('session').query(Reporter).all()

schema = graphene.Schema(query=Query)

with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
# Starts new session to fully reset the engine / connection logging level
session = session_factory()
schema.execute("""
query {
reporters {
articles {
edges {
node {
headline
}
}
}
}
}
""", context_value={"session": session})
messages = sqlalchemy_logging_handler.messages

select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
assert len(select_statements) == 2
Loading

0 comments on commit 4baae3e

Please sign in to comment.