diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py index f100be19..0adea107 100644 --- a/graphene_sqlalchemy/enums.py +++ b/graphene_sqlalchemy/enums.py @@ -1,3 +1,4 @@ +import six from sqlalchemy.orm import ColumnProperty from sqlalchemy.types import Enum as SQLAlchemyEnumType @@ -62,7 +63,7 @@ def enum_for_field(obj_type, field_name): if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) - if not field_name or not isinstance(field_name, str): + if not field_name or not isinstance(field_name, six.string_types): raise TypeError( "Expected a field name, but got: {!r}".format(field_name)) registry = obj_type._meta.registry diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 266b5f37..a9f514ba 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,6 +1,7 @@ import warnings from functools import partial +import six from promise import Promise, is_thenable from sqlalchemy.orm.query import Query @@ -35,7 +36,7 @@ def model(self): def get_query(cls, model, info, sort=None, **args): query = get_query(model, info.context) if sort is not None: - if isinstance(sort, str): + if isinstance(sort, six.string_types): query = query.order_by(sort.value) else: query = query.order_by(*(col.value for col in sort)) diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index acfa744b..c20bc2ca 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,5 +1,6 @@ from collections import defaultdict +import six from sqlalchemy.types import Enum as SQLAlchemyEnumType from graphene import Enum @@ -42,7 +43,7 @@ def register_orm_field(self, obj_type, field_name, orm_field): raise TypeError( "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) ) - if not field_name or not isinstance(field_name, str): + if not field_name or not isinstance(field_name, six.string_types): raise TypeError("Expected a field name, but got: {!r}".format(field_name)) self._registry_orm_fields[obj_type][field_name] = orm_field