diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/graphene_sqlalchemy/__init__.py new file mode 100644 index 00000000..71b40299 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/__init__.py @@ -0,0 +1,18 @@ +from .types import SQLAlchemyObjectType, SQLAlchemyInputObjectType, SQLAlchemyInterface, SQLAlchemyMutation, SQLAlchemyAutoSchemaFactory +from .fields import SQLAlchemyConnectionField, SQLAlchemyFilteredConnectionField +from .utils import get_query, get_session + +__version__ = "2.2.0b" + +__all__ = [ + "__version__", + "SQLAlchemyObjectType", + "SQLAlchemyConnectionField", + "SQLAlchemyFilteredConnectionField", + "SQLAlchemyInputObjectType", + "SQLAlchemyInterface", + "SQLAlchemyMutation", + "SQLAlchemyAutoSchemaFactory", + "get_query", + "get_session", +] diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/graphene_sqlalchemy/converter.py new file mode 100644 index 00000000..9466cbaf --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/converter.py @@ -0,0 +1,193 @@ +from singledispatch import singledispatch +from sqlalchemy import types +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import interfaces + +from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, + String) +from graphene.types.json import JSONString + +from .enums import enum_for_sa_enum +from .registry import get_global_registry + +try: + from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType +except ImportError: + ChoiceType = JSONType = ScalarListType = TSVectorType = object + + +def get_column_doc(column): + return getattr(column, "doc", None) + + +def is_column_nullable(column): + return bool(getattr(column, "nullable", True)) + + +def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): + direction = relationship.direction + model = relationship.mapper.entity + + def dynamic_type(): + _type = registry.get_type_for_model(model) + if not _type: + return None + if direction == interfaces.MANYTOONE or not relationship.uselist: + return Field(_type) + elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): + if _type._meta.connection: + return connection_field_factory(relationship, registry) + return Field(List(_type)) + + return Dynamic(dynamic_type) + + +def convert_sqlalchemy_hybrid_method(hybrid_item): + return String(description=getattr(hybrid_item, "__doc__", None), required=False) + + +def convert_sqlalchemy_composite(composite, registry): + converter = registry.get_converter_for_composite(composite.composite_class) + if not converter: + try: + raise Exception( + "Don't know how to convert the composite field %s (%s)" + % (composite, composite.composite_class) + ) + except AttributeError: + # handle fields that are not attached to a class yet (don't have a parent) + raise Exception( + "Don't know how to convert the composite field %r (%s)" + % (composite, composite.composite_class) + ) + return converter(composite, registry) + + +def _register_composite_class(cls, registry=None): + if registry is None: + from .registry import get_global_registry + + registry = get_global_registry() + + def inner(fn): + registry.register_composite_converter(cls, fn) + + return inner + + +convert_sqlalchemy_composite.register = _register_composite_class + + +def convert_sqlalchemy_column(column, registry=None): + return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) + + +@singledispatch +def convert_sqlalchemy_type(type, column, registry=None): + raise Exception( + "Don't know how to convert the SQLAlchemy field %s (%s)" + % (column, column.__class__) + ) + + +@convert_sqlalchemy_type.register(types.Date) +@convert_sqlalchemy_type.register(types.Time) +@convert_sqlalchemy_type.register(types.String) +@convert_sqlalchemy_type.register(types.Text) +@convert_sqlalchemy_type.register(types.Unicode) +@convert_sqlalchemy_type.register(types.UnicodeText) +@convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(postgresql.INET) +@convert_sqlalchemy_type.register(postgresql.CIDR) +@convert_sqlalchemy_type.register(TSVectorType) +def convert_column_to_string(type, column, registry=None): + return String( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(types.DateTime) +def convert_column_to_datetime(type, column, registry=None): + from graphene.types.datetime import DateTime + + return DateTime( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(types.SmallInteger) +@convert_sqlalchemy_type.register(types.Integer) +def convert_column_to_int_or_id(type, column, registry=None): + if column.primary_key: + return ID( + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) + else: + return Int( + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) + + +@convert_sqlalchemy_type.register(types.Boolean) +def convert_column_to_boolean(type, column, registry=None): + return Boolean( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(types.Float) +@convert_sqlalchemy_type.register(types.Numeric) +@convert_sqlalchemy_type.register(types.BigInteger) +def convert_column_to_float(type, column, registry=None): + return Float( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(types.Enum) +def convert_enum_to_enum(type, column, registry=None): + return Field( + lambda: enum_for_sa_enum(type, registry or get_global_registry()), + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) + + +@convert_sqlalchemy_type.register(ChoiceType) +def convert_choice_to_enum(type, column, registry=None): + name = "{}_{}".format(column.table.name, column.name).upper() + return Enum(name, type.choices, description=get_column_doc(column)) + + +@convert_sqlalchemy_type.register(ScalarListType) +def convert_scalar_list_to_list(type, column, registry=None): + return List(String, description=get_column_doc(column)) + + +@convert_sqlalchemy_type.register(postgresql.ARRAY) +def convert_postgres_array_to_list(_type, column, registry=None): + graphene_type = convert_sqlalchemy_type(column.type.item_type, column) + inner_type = type(graphene_type) + return List( + inner_type, + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) + + +@convert_sqlalchemy_type.register(postgresql.HSTORE) +@convert_sqlalchemy_type.register(postgresql.JSON) +@convert_sqlalchemy_type.register(postgresql.JSONB) +def convert_json_to_string(type, column, registry=None): + return JSONString( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(JSONType) +def convert_json_type_to_string(type, column, registry=None): + return JSONString( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/graphene_sqlalchemy/enums.py new file mode 100644 index 00000000..6b84bf52 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/enums.py @@ -0,0 +1,203 @@ +from sqlalchemy import Column +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Argument, Enum, List + +from .utils import EnumValue, to_enum_value_name, to_type_name + + +def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): + """Convert the given SQLAlchemy Enum type to a Graphene Enum type. + + The name of the Graphene Enum will be determined as follows: + If the SQLAlchemy Enum is based on a Python Enum, use the name + of the Python Enum. Otherwise, if the SQLAlchemy Enum is named, + use the SQL name after conversion to a type name. Otherwise, use + the given fallback_name or raise an error if it is empty. + + The Enum value names are converted to upper case if necessary. + """ + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum_class = sa_enum.enum_class + if enum_class: + if all(to_enum_value_name(key) == key for key in enum_class.__members__): + return Enum.from_enum(enum_class) + name = enum_class.__name__ + members = [ + (to_enum_value_name(key), value.value) + for key, value in enum_class.__members__.items() + ] + else: + sql_enum_name = sa_enum.name + if sql_enum_name: + name = to_type_name(sql_enum_name) + elif fallback_name: + name = fallback_name + else: + raise TypeError("No type name specified for {!r}".format(sa_enum)) + members = [(to_enum_value_name(key), key) for key in sa_enum.enums] + return Enum(name, members) + + +def enum_for_sa_enum(sa_enum, registry): + """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + enum = _convert_sa_to_graphene_enum(sa_enum) + registry.register_enum(sa_enum, enum) + return enum + + +def enum_for_field(obj_type, field_name): + """Return the Graphene Enum type for the specified Graphene field.""" + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + if not field_name or not isinstance(field_name, str): + raise TypeError( + "Expected a field name, but got: {!r}".format(field_name)) + registry = obj_type._meta.registry + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if orm_field is None: + raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) + if not isinstance(orm_field, Column): + raise TypeError( + "{}.{} does not map to model column".format(obj_type._meta.name, field_name) + ) + sa_enum = orm_field.type + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + fallback_name = obj_type._meta.name + to_type_name(field_name) + enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name) + registry.register_enum(sa_enum, enum) + return enum + + +def _default_sort_enum_symbol_name(column_name, sort_asc=True): + return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC") + + +def sort_enum_for_object_type( + obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None +): + """Return Graphene Enum for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Enum shall be generated. + - name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. + - get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + + Returns + - Enum + The Graphene Enum type + """ + name = name or obj_type._meta.name + "SortEnum" + registry = obj_type._meta.registry + enum = registry.get_sort_enum_for_object_type(obj_type) + custom_options = dict( + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if enum: + if name != enum.__name__ or custom_options != enum.custom_options: + raise ValueError( + "Sort enum for {} has already been customized".format(obj_type) + ) + else: + members = [] + default = [] + fields = obj_type._meta.fields + get_name = get_symbol_name or _default_sort_enum_symbol_name + for field_name in fields: + if only_fields and field_name not in only_fields: + continue + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if not isinstance(orm_field, Column): + continue + if only_indexed and not (orm_field.primary_key or orm_field.index): + continue + asc_name = get_name(orm_field.name, True) + asc_value = EnumValue(asc_name, orm_field.asc()) + desc_name = get_name(orm_field.name, False) + desc_value = EnumValue(desc_name, orm_field.desc()) + if orm_field.primary_key: + default.append(asc_value) + members.extend(((asc_name, asc_value), (desc_name, desc_value))) + enum = Enum(name, members) + enum.default = default # store default as attribute + enum.custom_options = custom_options + registry.register_sort_enum(obj_type, enum) + return enum + + +def sort_argument_for_object_type( + obj_type, + enum_name=None, + only_fields=None, + only_indexed=None, + get_symbol_name=None, + has_default=True, +): + """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Argument shall be generated. + - enum_name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. + - get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + - has_default : bool, optional, default True + If this is set to False, no sorting will happen when this argument is not + passed. Otherwise results will be sortied by the primary key(s) of the model. + + Returns + - Enum + A Graphene Argument that accepts a list of sorting directions for the model. + """ + enum = sort_enum_for_object_type( + obj_type, + enum_name, + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if not has_default: + enum.default = None + + return Argument(List(enum), default_value=enum.default) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/graphene_sqlalchemy/fields.py new file mode 100644 index 00000000..e0b66282 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/fields.py @@ -0,0 +1,255 @@ +import logging +from collections import OrderedDict +from functools import partial +from typing import Mapping +from uuid import UUID + +from graphene import Argument, InputObjectType, Field, List +from graphene.utils.str_converters import to_snake_case +from graphql import ResolveInfo +from promise import Promise, is_thenable +from sqlalchemy import inspect, func +from sqlalchemy.orm.query import Query + +from graphene.relay import Connection, ConnectionField +from graphene.relay.connection import PageInfo +from graphql_relay.connection.arrayconnection import connection_from_list_slice + +from graphene_sqlalchemy.converter import convert_sqlalchemy_type +from .utils import get_query + +log = logging.getLogger() + +argument_cache = {} +field_cache = {} + + +class UnsortedSQLAlchemyConnectionField(ConnectionField): + @property + def type(self, assert_type: bool= True): + from .types import SQLAlchemyObjectType, SQLAlchemyInputObjectType + from .interfaces import SQLAlchemyInterface + _type = super(ConnectionField, self).type + if issubclass(_type, Connection): + return _type + + if assert_type: + assert issubclass(_type, (SQLAlchemyObjectType, SQLAlchemyInterface, SQLAlchemyInputObjectType)), ( + "SQLALchemyConnectionField only accepts {} types, not {}" + ).format([x.__name__ for x in (SQLAlchemyObjectType, SQLAlchemyInterface, SQLAlchemyInputObjectType)], _type.__name__) + assert _type._meta.connection\ + , "The type {} doesn't have a connection".format( + _type.__name__ + ) + return _type._meta.connection + + @property + def model(self): + return self.type._meta.node._meta.model + + @classmethod + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if isinstance(sort, str): + query = query.order_by(sort.value) + else: + query = query.order_by(*(col.value for col in sort)) + return query + + @classmethod + def resolve_connection(cls, connection_type, model, info, args, resolved): + if resolved is None: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() + else: + _len = len(resolved) + connection = connection_from_list_slice( + resolved, + args, + slice_start=0, + list_length=_len, + list_slice_length=_len, + connection_type=connection_type, + pageinfo_type=PageInfo, + edge_type=connection_type.Edge, + ) + connection.iterable = resolved + connection.length = _len + return connection + + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + resolved = resolver(root, info, **args) + + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) + + def get_resolver(self, parent_resolver): + return partial(self.connection_resolver, parent_resolver, self.type, self.model) + + +class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): + def __init__(self, type, *args, **kwargs): + if "sort" not in kwargs and issubclass(type, Connection): + # Let super class raise if type is not a Connection + try: + kwargs.setdefault("sort", type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + + +class FilterArgument: + pass + + +class FilterField: + pass + + +def create_filter_argument(cls): + name = "{}Filter".format(cls.__name__) + if name in argument_cache: + return Argument(argument_cache[name]) + import re + + NAME_PATTERN = r"^[_a-zA-Z][_a-zA-Z0-9]*$" + COMPILED_NAME_PATTERN = re.compile(NAME_PATTERN) + fields = OrderedDict((column.name, field) + for column, field in [(column, create_filter_field(column)) + for column in inspect(cls).columns.values()] if field and COMPILED_NAME_PATTERN.match(column.name)) + argument_class: InputObjectType = type(name, (FilterArgument, InputObjectType), {}) + argument_class._meta.fields.update(fields) + argument_cache[name] = argument_class + return Argument(argument_class) + +def filter_query(query, model, field, value): + if isinstance(value, Mapping): + [(operator, value)] = value.items() + # does not work on UUID columns + if operator is "equal": + query = query.filter(getattr(model, field) == value) + elif operator is "notEqual": + query = query.filter(getattr(model, field) != value) + elif operator is "lessThan": + query = query.filter(getattr(model, field) < value) + elif operator is "greaterThan": + query = query.filter(getattr(model, field) > value) + elif operator is "like": + query = query.filter(func.lower(getattr(model, field)).like(func.lower(f'%{value}%'))) + elif operator is "in": + query = query.filter(getattr(model, field).in_(value)) + elif isinstance(value, (str, int, UUID)): + query = query.filter(getattr(model, field) == value) + else: + raise NotImplementedError(f'Filter for value type {type(value)} for {field} of model {model} is not implemented') + return query + + +def create_filter_field(column): + graphene_type = convert_sqlalchemy_type(column.type, column) + if graphene_type.__class__ == Field: + return None + + name = "{}Filter".format(str(graphene_type.__class__)) + if name in field_cache: + return Field(field_cache[name]) + + fields = OrderedDict((key, Field(graphene_type.__class__)) + for key in ["equal", "notEqual", "lessThan", "greaterThan", "like"]) + fields['in'] = Field(List(graphene_type.__class__)) + field_class: InputObjectType = type(name, (FilterField, InputObjectType), {}) + field_class._meta.fields.update(fields) + + field_cache[name] = field_class + return Field(field_class) + + +class SQLAlchemyFilteredConnectionField(SQLAlchemyConnectionField): + def __init__(self, type_, *args, **kwargs): + model = type_._meta.model + kwargs.setdefault("filter", create_filter_argument(model)) + super(SQLAlchemyFilteredConnectionField, self).__init__(type_, *args, **kwargs) + + @classmethod + def get_query(cls, model, info: ResolveInfo, filter=None, sort=None, group_by=None, order_by=None, **kwargs): + query = super().get_query(model, info, sort=None, **kwargs) + # columns = inspect(model).columns.values() + from graphene_sqlalchemy.types import SQLAlchemyInputObjectType + + for filter_name, filter_value in kwargs.items(): + model_filter_column = getattr(model, filter_name, None) + if not model_filter_column: + continue + if isinstance(filter_value, SQLAlchemyInputObjectType): + filter_model = filter_value.sqla_model + q = super().get_query(filter_model, info, sort=None, **kwargs) + # noinspection PyArgumentList + query.filter(model_filter_column == q.filter_by(**filter_value).one()) + if filter: + for filter_name, filter_value in filter.items(): + query = filter_query(query, model, filter_name, filter_value) + return query + + @classmethod + def resolve_connection(cls, connection_type, model, info, args, resolved): + filters = args.get("filter", {}) + field = getattr(info.schema._query, to_snake_case(info.field_name)) + if field and hasattr(field, 'required') and field.required: + required_filters = [rf.key for rf in field.required] + + if required_filters: + missing_filters = set(required_filters) - set(filters.keys()) + if missing_filters: + raise Exception(missing_filters) + + return super(SQLAlchemyFilteredConnectionField, cls).resolve_connection( + connection_type, model, info, args, resolved) + + +def default_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return createConnectionField(model_type) + + +# TODO Remove in next major version +__connectionFactory = UnsortedSQLAlchemyConnectionField + + +def createConnectionField(_type): + log.warning( + 'createConnectionField is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) + return __connectionFactory(_type) + + +def registerConnectionFieldFactory(factoryMethod): + log.warning( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) + global __connectionFactory + __connectionFactory = factoryMethod + + +def unregisterConnectionFieldFactory(): + log.warning( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) + global __connectionFactory + __connectionFactory = UnsortedSQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/interfaces.py b/graphene_sqlalchemy/graphene_sqlalchemy/interfaces.py new file mode 100644 index 00000000..0b031d87 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/interfaces.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, Union +from uuid import UUID + +import graphene +import sqlalchemy +from graphene.relay.node import NodeField, AbstractNode +from graphene.types.interface import InterfaceOptions +from sqlalchemy.ext.declarative import DeclarativeMeta + +from .registry import Registry +from .utils import is_mapped_class + +if TYPE_CHECKING: + from typing import List + +from graphene import Field +from graphene.relay import Connection, Node +from graphene.types.utils import yank_fields_from_attrs +from .fields import default_connection_field_factory, UnsortedSQLAlchemyConnectionField +from .registry import get_global_registry +import logging + +log = logging.getLogger(__name__) + + +class SQLAlchemyInterfaceOptions(InterfaceOptions): + model: DeclarativeMeta = None + registry: Registry = None + connection: Connection = None + id: Union[str, int, UUID] = None + + +def exclude_autogenerated_sqla_columns(model: DeclarativeMeta) -> Tuple[str]: + # always pull ids out to a separate argument + autoexclude: List[str] = [] + for col in sqlalchemy.inspect(model).columns: + if ((col.primary_key and col.autoincrement) or + (isinstance(col.type, sqlalchemy.types.TIMESTAMP) and + col.server_default is not None)): + autoexclude.append(col.name) + assert isinstance(col.name, str) + return tuple(autoexclude) + + +class SQLAlchemyInterface(Node): + class SQLAlchemyInterface(Node): + @classmethod + def __init_subclass_with_meta__( + cls, + model: DeclarativeMeta = None, + registry: Registry = None, + only_fields: Tuple[str] = (), + exclude_fields: Tuple[str] = (), + connection_field_factory: UnsortedSQLAlchemyConnectionField = default_connection_field_factory, + **options + ): + _meta = SQLAlchemyInterfaceOptions(cls) + _meta.name = f'{cls.__name__}Node' + + autoexclude_columns = exclude_autogenerated_sqla_columns(model=model) + exclude_fields += autoexclude_columns + + assert is_mapped_class(model), ( + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' + ).format(cls.__name__, model) + + if not registry: + registry = get_global_registry() + + assert isinstance(registry, Registry), ( + "The attribute registry in {} needs to be an instance of " + 'Registry, received "{}".' + ).format(cls.__name__, registry) + from .types import construct_fields + + sqla_fields = yank_fields_from_attrs( + construct_fields( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + connection_field_factory=connection_field_factory + ), + _as=Field + ) + if not _meta: + _meta = SQLAlchemyInterfaceOptions(cls) + _meta.model = model + _meta.registry = registry + connection = Connection.create_type( + "{}Connection".format(cls.__name__), node=cls) + assert issubclass(connection, Connection), ( + "The connection must be a Connection. Received {}" + ).format(connection.__name__) + _meta.connection = connection + if _meta.fields: + _meta.fields.update(sqla_fields) + else: + _meta.fields = sqla_fields + _meta.fields['id'] = graphene.GlobalID(cls, description="The ID of the object.") + # call super of AbstractNode directly because it creates its own _meta, which we don't want + super(AbstractNode, cls).__init_subclass_with_meta__(_meta=_meta, **options) + + @classmethod + def Field(cls, *args, **kwargs): # noqa: N802 + return NodeField(cls, *args, **kwargs) + + @classmethod + def node_resolver(cls, only_type, root, info, id): + return cls.get_node_from_global_id(info, id, only_type=only_type) + + @classmethod + def get_node_from_global_id(cls, info, global_id, only_type=None): + try: + node: DeclarativeMeta = info.context.get('session').query(cls._meta.model).filter_by(id=global_id).one() + return node + except Exception: + return None + + @classmethod + def from_global_id(cls, global_id): + return global_id + + @classmethod + def to_global_id(cls, type, id): + return id + + @classmethod + def resolve_type(cls, instance, info): + if isinstance(instance, graphene.ObjectType): + return type(instance) + graphene_model = get_global_registry().get_type_for_model(type(instance)) + if graphene_model: + return graphene_model + else: + raise ValueError(f'{instance} must be a SQLAlchemy model or graphene.ObjectType') diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/graphene_sqlalchemy/registry.py new file mode 100644 index 00000000..9a516596 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/registry.py @@ -0,0 +1,103 @@ +from collections import defaultdict + +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + + +class Registry(object): + def __init__(self): + self._registry = {} + self._registry_models = {} + self._registry_orm_fields = defaultdict(dict) + self._registry_composites = {} + self._registry_enums = {} + self._registry_sort_enums = {} + + def register(self, obj_type): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + assert obj_type._meta.registry == self, "Registry for a Model have to match." + # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( + # 'SQLAlchemy model "{}" already associated with ' + # 'another type "{}".' + # ).format(cls._meta.model, self._registry[cls._meta.model]) + self._registry[obj_type._meta.model] = obj_type + + def get_type_for_model(self, model): + return self._registry.get(model) + + def register_orm_field(self, obj_type, field_name, orm_field, assert_type: bool = True): + from .types import SQLAlchemyObjectType, SQLAlchemyInputObjectType + from .interfaces import SQLAlchemyInterface + if assert_type: + if not isinstance(obj_type, type) or not issubclass( + obj_type, (SQLAlchemyObjectType, SQLAlchemyInterface, SQLAlchemyInputObjectType) + ): + raise TypeError( + "Expected one of {}, but got: {!r}".format([x.__name__ for x in (SQLAlchemyObjectType, SQLAlchemyInterface, SQLAlchemyInputObjectType)], obj_type) + ) + if not field_name or not isinstance(field_name, str): + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) + self._registry_orm_fields[obj_type][field_name] = orm_field + + def get_orm_field_for_graphene_field(self, obj_type, field_name): + return self._registry_orm_fields.get(obj_type, {}).get(field_name) + + def register_composite_converter(self, composite, converter): + self._registry_composites[composite] = converter + + def get_converter_for_composite(self, composite): + return self._registry_composites.get(composite) + + def register_enum(self, sa_enum, graphene_enum): + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) + ) + if not isinstance(graphene_enum, type(Enum)): + raise TypeError( + "Expected Graphene Enum, but got: {!r}".format(graphene_enum) + ) + + self._registry_enums[sa_enum] = graphene_enum + + def get_graphene_enum_for_sa_enum(self, sa_enum): + return self._registry_enums.get(sa_enum) + + def register_sort_enum(self, obj_type, sort_enum): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + if not isinstance(sort_enum, type(Enum)): + raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) + self._registry_sort_enums[obj_type] = sort_enum + + def get_sort_enum_for_object_type(self, obj_type): + return self._registry_sort_enums.get(obj_type) + + +registry = None + + +def get_global_registry(): + global registry + if not registry: + registry = Registry() + return registry + + +def reset_global_registry(): + global registry + registry = None diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/__init__.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/conftest.py new file mode 100644 index 00000000..2825eb3c --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/conftest.py @@ -0,0 +1,32 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from ..registry import reset_global_registry +from .models import Base + +test_db_url = 'sqlite://' # use in-memory database for tests + + +@pytest.fixture(autouse=True) +def reset_registry(): + reset_global_registry() + + +@pytest.yield_fixture(scope="function") +def session(): + db = create_engine(test_db_url) + connection = db.engine.connect() + transaction = connection.begin() + Base.metadata.create_all(connection) + + # options = dict(bind=connection, binds={}) + session_factory = sessionmaker(bind=connection) + session = scoped_session(session_factory) + + yield session + + # Finalize test here + transaction.rollback() + connection.close() + session.remove() diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/models.py new file mode 100644 index 00000000..12781cc5 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/models.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import + +import enum + +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import mapper, relationship + +PetKind = Enum("cat", "dog", name="pet_kind") + + +class HairKind(enum.Enum): + LONG = 'long' + SHORT = 'short' + + +Base = declarative_base() + +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) + + +class Editor(Base): + __tablename__ = "editors" + editor_id = Column(Integer(), primary_key=True) + name = Column(String(100)) + + +class Pet(Base): + __tablename__ = "pets" + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class Reporter(Base): + __tablename__ = "reporters" + id = Column(Integer(), primary_key=True) + first_name = Column(String(30)) + last_name = Column(String(30)) + email = Column(String()) + favorite_pet_kind = Column(PetKind) + pets = relationship("Pet", secondary=association_table, backref="reporters") + articles = relationship("Article", backref="reporter") + favorite_article = relationship("Article", uselist=False) + + # total = column_property( + # select([ + # func.cast(func.count(PersonInfo.id), Float) + # ]) + # ) + + +class Article(Base): + __tablename__ = "articles" + id = Column(Integer(), primary_key=True) + headline = Column(String(100)) + pub_date = Column(Date()) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class ReflectedEditor(type): + """Same as Editor, but using reflected table.""" + + @classmethod + def __subclasses__(cls): + return [] + + +editor_table = Table("editors", Base.metadata, autoload=True) + +mapper(ReflectedEditor, editor_table) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_converter.py new file mode 100644 index 00000000..f38999d2 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_converter.py @@ -0,0 +1,372 @@ +import enum + +import pytest +from sqlalchemy import Column, Table, case, func, select, types +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import column_property, composite +from sqlalchemy.sql.elements import Label +from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType + +import graphene +from graphene.relay import Node +from graphene.types.datetime import DateTime +from graphene.types.json import JSONString + +from ..converter import (convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_relationship) +from ..fields import (UnsortedSQLAlchemyConnectionField, + default_connection_field_factory) +from ..registry import Registry +from ..types import SQLAlchemyObjectType +from .models import Article, Pet, Reporter + + +def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): + column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) + graphene_type = convert_sqlalchemy_column(column) + assert isinstance(graphene_type, graphene_field) + field = ( + graphene_type + if isinstance(graphene_type, graphene.Field) + else graphene_type.Field() + ) + assert field.description == "Custom Help Text" + return field + + +def assert_composite_conversion( + composite_class, composite_columns, graphene_field, registry, **kwargs +): + composite_column = composite( + composite_class, *composite_columns, doc="Custom Help Text", **kwargs + ) + graphene_type = convert_sqlalchemy_composite(composite_column, registry) + assert isinstance(graphene_type, graphene_field) + field = graphene_type.Field() + # SQLAlchemy currently does not persist the doc onto the column, even though + # the documentation says it does.... + # assert field.description == 'Custom Help Text' + return field + + +def test_should_unknown_sqlalchemy_field_raise_exception(): + re_err = "Don't know how to convert the SQLAlchemy field" + with pytest.raises(Exception, match=re_err): + convert_sqlalchemy_column(None) + + +def test_should_date_convert_string(): + assert_column_conversion(types.Date(), graphene.String) + + +def test_should_datetime_convert_string(): + assert_column_conversion(types.DateTime(), DateTime) + + +def test_should_time_convert_string(): + assert_column_conversion(types.Time(), graphene.String) + + +def test_should_string_convert_string(): + assert_column_conversion(types.String(), graphene.String) + + +def test_should_text_convert_string(): + assert_column_conversion(types.Text(), graphene.String) + + +def test_should_unicode_convert_string(): + assert_column_conversion(types.Unicode(), graphene.String) + + +def test_should_unicodetext_convert_string(): + assert_column_conversion(types.UnicodeText(), graphene.String) + + +def test_should_enum_convert_enum(): + field = assert_column_conversion( + types.Enum(enum.Enum("TwoNumbers", ("one", "two"))), graphene.Field + ) + field_type = field.type() + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + field = assert_column_conversion( + types.Enum("one", "two", name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type._meta.name == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + +def test_should_not_enum_convert_enum_without_name(): + field = assert_column_conversion( + types.Enum("one", "two"), graphene.Field + ) + re_err = r"No type name specified for Enum\('one', 'two'\)" + with pytest.raises(TypeError, match=re_err): + field.type() + + +def test_should_small_integer_convert_int(): + assert_column_conversion(types.SmallInteger(), graphene.Int) + + +def test_should_big_integer_convert_int(): + assert_column_conversion(types.BigInteger(), graphene.Float) + + +def test_should_integer_convert_int(): + assert_column_conversion(types.Integer(), graphene.Int) + + +def test_should_integer_convert_id(): + assert_column_conversion(types.Integer(), graphene.ID, primary_key=True) + + +def test_should_boolean_convert_boolean(): + assert_column_conversion(types.Boolean(), graphene.Boolean) + + +def test_should_float_convert_float(): + assert_column_conversion(types.Float(), graphene.Float) + + +def test_should_numeric_convert_float(): + assert_column_conversion(types.Numeric(), graphene.Float) + + +def test_should_label_convert_string(): + label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) + graphene_type = convert_sqlalchemy_column(label) + assert isinstance(graphene_type, graphene.String) + + +def test_should_label_convert_int(): + label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) + graphene_type = convert_sqlalchemy_column(label) + assert isinstance(graphene_type, graphene.Int) + + +def test_should_choice_convert_enum(): + TYPES = [(u"es", u"Spanish"), (u"en", u"English")] + column = Column(ChoiceType(TYPES), doc="Language", name="language") + Base = declarative_base() + + Table("translatedmodel", Base.metadata, column) + graphene_type = convert_sqlalchemy_column(column) + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" + assert graphene_type._meta.description == "Language" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" + + +def test_should_columproperty_convert(): + + Base = declarative_base() + + class Test(Base): + __tablename__ = "test" + id = Column(types.Integer, primary_key=True) + column = column_property( + select([func.sum(func.cast(id, types.Integer))]).where(id == 1) + ) + + graphene_type = convert_sqlalchemy_column(Test.column) + assert not graphene_type.kwargs["required"] + + +def test_should_scalar_list_convert_list(): + assert_column_conversion(ScalarListType(), graphene.List) + + +def test_should_jsontype_convert_jsonstring(): + assert_column_conversion(JSONType(), JSONString) + + +def test_should_manytomany_convert_connectionorlist(): + registry = Registry() + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + assert not dynamic_field.get_type() + + +def test_should_manytomany_convert_connectionorlist_list(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A._meta.registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert isinstance(graphene_type.type, graphene.List) + assert graphene_type.type.of_type == A + + +def test_should_manytomany_convert_connectionorlist_connection(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A._meta.registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) + + +def test_should_manytoone_convert_connectionorlist(): + registry = Registry() + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + assert not dynamic_field.get_type() + + +def test_should_manytoone_convert_connectionorlist_list(): + class A(SQLAlchemyObjectType): + class Meta: + model = Reporter + + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, A._meta.registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert graphene_type.type == A + + +def test_should_manytoone_convert_connectionorlist_connection(): + class A(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, A._meta.registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert graphene_type.type == A + + +def test_should_onetoone_convert_field(): + class A(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.favorite_article.property, + A._meta.registry, + default_connection_field_factory, + ) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert graphene_type.type == A + + +def test_should_postgresql_uuid_convert(): + assert_column_conversion(postgresql.UUID(), graphene.String) + + +def test_should_postgresql_enum_convert(): + field = assert_column_conversion( + postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type._meta.name == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + +def test_should_postgresql_py_enum_convert(): + field = assert_column_conversion( + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), + graphene.Field, + ) + field_type = field.type() + assert field_type._meta.name == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + +def test_should_postgresql_array_convert(): + assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) + + +def test_should_postgresql_json_convert(): + assert_column_conversion(postgresql.JSON(), JSONString) + + +def test_should_postgresql_jsonb_convert(): + assert_column_conversion(postgresql.JSONB(), JSONString) + + +def test_should_postgresql_hstore_convert(): + assert_column_conversion(postgresql.HSTORE(), JSONString) + + +def test_should_composite_convert(): + class CompositeClass: + def __init__(self, col1, col2): + self.col1 = col1 + self.col2 = col2 + + registry = Registry() + + @convert_sqlalchemy_composite.register(CompositeClass, registry) + def convert_composite_class(composite, registry): + return graphene.String(description=composite.doc) + + assert_composite_conversion( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + graphene.String, + registry, + ) + + +def test_should_unknown_sqlalchemy_composite_raise_exception(): + registry = Registry() + + re_err = "Don't know how to convert the composite field" + with pytest.raises(Exception, match=re_err): + + class CompositeClass(object): + def __init__(self, col1, col2): + self.col1 = col1 + self.col2 = col2 + + assert_composite_conversion( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + graphene.String, + registry, + ) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_enums.py new file mode 100644 index 00000000..ca376964 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_enums.py @@ -0,0 +1,122 @@ +from enum import Enum as PyEnum + +import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + +from ..enums import _convert_sa_to_graphene_enum, enum_for_field +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet + + +def test_convert_sa_to_graphene_enum_bad_type(): + re_err = "Expected sqlalchemy.types.Enum, but got: 'foo'" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum("foo") + + +def test_convert_sa_to_graphene_enum_based_on_py_enum(): + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is Color + + +def test_convert_sa_to_graphene_enum_based_on_py_enum_with_bad_names(): + class Color(PyEnum): + red = 1 + green = 2 + blue = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is not Color + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue", name="color_values") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "ColorValues" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "FallbackName" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + re_err = r"No type name specified for Enum\('red', 'green', 'blue'\)" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum(sa_enum) + + +def test_enum_for_field(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + enum = enum_for_field(PetType, 'pet_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "PetKind" + assert [ + (key, value.value) + for key, value in enum._meta.enum.__members__.items() + ] == [("CAT", 'cat'), ("DOG", 'dog')] + enum2 = enum_for_field(PetType, 'pet_kind') + assert enum2 is enum + enum2 = PetType.enum_for_field('pet_kind') + assert enum2 is enum + + enum = enum_for_field(PetType, 'hair_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "HairKind" + assert enum._meta.enum is HairKind + enum2 = PetType.enum_for_field('hair_kind') + assert enum2 is enum + + re_err = r"Cannot get PetType\.other_kind" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'other_kind') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('other_kind') + + re_err = r"PetType\.name does not map to enum column" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'name') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('name') + + re_err = r"Expected a field name, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, None) + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field(None) + + re_err = "Expected SQLAlchemyObjectType, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(None, 'other_kind') diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_fields.py new file mode 100644 index 00000000..0f8738f0 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_fields.py @@ -0,0 +1,44 @@ +import pytest + +from graphene.relay import Connection + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from .models import Editor as EditorModel +from .models import Pet as PetModel + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + + +class Editor(SQLAlchemyObjectType): + class Meta: + model = EditorModel + + +class PetConn(Connection): + class Meta: + node = Pet + + +def test_sort_added_by_default(): + field = SQLAlchemyConnectionField(PetConn) + assert "sort" in field.args + assert field.args["sort"] == Pet.sort_argument() + + +def test_sort_can_be_removed(): + field = SQLAlchemyConnectionField(PetConn, sort=None) + assert "sort" not in field.args + + +def test_custom_sort(): + field = SQLAlchemyConnectionField(PetConn, sort=Editor.sort_argument()) + assert field.args["sort"] == Editor.sort_argument() + + +def test_init_raises(): + with pytest.raises(TypeError, match="Cannot create sort"): + SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query.py new file mode 100644 index 00000000..5279bd87 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query.py @@ -0,0 +1,285 @@ +import graphene +from graphene.relay import Connection, Node + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from .models import Article, Editor, HairKind, Pet, Reporter + + +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value + + +def add_test_data(session): + reporter = Reporter( + first_name='John', last_name='Doe', favorite_pet_kind='cat') + session.add(reporter) + pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) + session.add(pet) + pet.reporters.append(reporter) + article = Article(headline='Hi!') + article.reporter = reporter + session.add(article) + reporter = Reporter( + first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + session.add(reporter) + pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet.reporters.append(reporter) + session.add(pet) + editor = Editor(name="Jack") + session.add(editor) + session.commit() + + +def test_should_query_well(session): + add_test_data(session) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + reporters = graphene.List(ReporterType) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_reporters(self, _info): + return session.query(Reporter) + + query = """ + query ReporterQuery { + reporter { + firstName + lastName + email + } + reporters { + firstName + } + } + """ + expected = { + "reporter": {"firstName": "John", "lastName": "Doe", "email": None}, + "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_should_query_node(session): + add_test_data(session) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class ArticleConnection(Connection): + class Meta: + node = ArticleNode + + class Query(graphene.ObjectType): + node = Node.Field() + reporter = graphene.Field(ReporterNode) + article = graphene.Field(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleConnection) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_article(self, _info): + return session.query(Article).first() + + query = """ + query ReporterQuery { + reporter { + id + firstName, + articles { + edges { + node { + headline + } + } + } + lastName, + email + } + allArticles { + edges { + node { + headline + } + } + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... on ArticleNode { + headline + } + } + } + """ + expected = { + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "John", + "lastName": "Doe", + "email": None, + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_should_custom_identifier(session): + add_test_data(session) + + class EditorNode(SQLAlchemyObjectType): + class Meta: + model = Editor + interfaces = (Node,) + + class EditorConnection(Connection): + class Meta: + node = EditorNode + + class Query(graphene.ObjectType): + node = Node.Field() + all_editors = SQLAlchemyConnectionField(EditorConnection) + + query = """ + query EditorQuery { + allEditors { + edges { + node { + id + name + } + } + }, + node(id: "RWRpdG9yTm9kZTox") { + ...on EditorNode { + name + } + } + } + """ + expected = { + "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]}, + "node": {"name": "Jack"}, + } + + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_should_mutate_well(session): + add_test_data(session) + + class EditorNode(SQLAlchemyObjectType): + class Meta: + model = Editor + interfaces = (Node,) + + class ReporterNode(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name="Cookie Monster") + + class ArticleNode(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class CreateArticle(graphene.Mutation): + class Arguments: + headline = graphene.String() + reporter_id = graphene.ID() + + ok = graphene.Boolean() + article = graphene.Field(ArticleNode) + + def mutate(self, info, headline, reporter_id): + new_article = Article(headline=headline, reporter_id=reporter_id) + + session.add(new_article) + session.commit() + ok = True + + return CreateArticle(article=new_article, ok=ok) + + class Query(graphene.ObjectType): + node = Node.Field() + + class Mutation(graphene.ObjectType): + create_article = CreateArticle.Field() + + query = """ + mutation ArticleCreator { + createArticle( + headline: "My Article" + reporterId: "1" + ) { + ok + article { + headline + reporter { + id + firstName + } + } + } + } + """ + expected = { + "createArticle": { + "ok": True, + "article": { + "headline": "My Article", + "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"}, + }, + } + } + + schema = graphene.Schema(query=Query, mutation=Mutation) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query_enums.py new file mode 100644 index 00000000..ec585d57 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_query_enums.py @@ -0,0 +1,198 @@ +import graphene + +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet, Reporter +from .test_query import add_test_data, to_std_dicts + + +def test_query_pet_kinds(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + + class Meta: + model = Pet + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + reporters = graphene.List(ReporterType) + pets = graphene.List(PetType, kind=graphene.Argument( + PetType.enum_for_field('pet_kind'))) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_reporters(self, _info): + return session.query(Reporter) + + def resolve_pets(self, _info, kind): + query = session.query(Pet) + if kind: + query = query.filter_by(pet_kind=kind) + return query + + query = """ + query ReporterQuery { + reporter { + firstName + lastName + email + favoritePetKind + pets { + name + petKind + } + } + reporters { + firstName + favoritePetKind + } + pets(kind: DOG) { + name + petKind + } + } + """ + expected = { + 'reporter': { + 'firstName': 'John', + 'lastName': 'Doe', + 'email': None, + 'favoritePetKind': 'CAT', + 'pets': [{ + 'name': 'Garfield', + 'petKind': 'CAT' + }] + }, + 'reporters': [{ + 'firstName': 'John', + 'favoritePetKind': 'CAT', + }, { + 'firstName': 'Jane', + 'favoritePetKind': 'DOG', + }], + 'pets': [{ + 'name': 'Lassie', + 'petKind': 'DOG' + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_query_more_enums(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType) + + def resolve_pet(self, _info): + return session.query(Pet).first() + + query = """ + query PetQuery { + pet { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + + def resolve_pet(self, info, kind=None): + query = session.query(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind) + return query.first() + + query = """ + query PetQuery($kind: PetKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "CAT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "DOG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected + + +def test_py_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), + ) + + def resolve_pet(self, _info, kind=None): + query = session.query(Pet) + if kind: + # enum arguments are expected to be strings, not PyEnums + query = query.filter(Pet.hair_kind == HairKind(kind)) + return query.first() + + query = """ + query PetQuery($kind: HairKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "SHORT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "LONG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_reflected.py new file mode 100644 index 00000000..46e10de9 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_reflected.py @@ -0,0 +1,20 @@ + +from graphene import ObjectType + +from ..registry import Registry +from ..types import SQLAlchemyObjectType +from .models import ReflectedEditor + +registry = Registry() + + +class Reflected(SQLAlchemyObjectType): + class Meta: + model = ReflectedEditor + registry = registry + + +def test_objecttype_registered(): + assert issubclass(Reflected, ObjectType) + assert Reflected._meta.model == ReflectedEditor + assert list(Reflected._meta.fields.keys()) == ["editor_id", "name"] diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_registry.py new file mode 100644 index 00000000..71d225a8 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_registry.py @@ -0,0 +1,128 @@ +import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnum + +from graphene import Enum as GrapheneEnum + +from ..registry import Registry +from ..types import SQLAlchemyObjectType +from ..utils import EnumValue +from .models import Pet + + +def test_register_object_type(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register(PetType) + assert reg.get_type_for_model(Pet) is PetType + + +def test_register_incorrect_object_type(): + reg = Registry() + + class Spam: + pass + + re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register(Spam) + + +def test_register_orm_field(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register_orm_field(PetType, "name", Pet.name) + assert reg.get_orm_field_for_graphene_field(PetType, "name") is Pet.name + + +def test_register_orm_field_incorrect_types(): + reg = Registry() + + class Spam: + pass + + re_err = r"Expected one of \['SQLAlchemyObjectType', 'SQLAlchemyInterface', 'SQLAlchemyInputObjectType'\], but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(Spam, "name", Pet.name) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + re_err = "Expected a field name, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(PetType, Spam, Pet.name) + + +def test_register_enum(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + reg.register_enum(sa_enum, graphene_enum) + assert reg.get_graphene_enum_for_sa_enum(sa_enum) is graphene_enum + + +def test_register_enum_incorrect_types(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + re_err = r"Expected Graphene Enum, but got: Enum\('cat', 'dog'\)" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(sa_enum, sa_enum) + + re_err = r"Expected SQLAlchemyEnumType, but got: .*PetKind.*" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(graphene_enum, graphene_enum) + + +def test_register_sort_enum(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], + ) + + reg.register_sort_enum(PetType, sort_enum) + assert reg.get_sort_enum_for_object_type(PetType) is sort_enum + + +def test_register_sort_enum_incorrect_types(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], + ) + + re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(sort_enum, sort_enum) + + re_err = r"Expected Graphene Enum, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(PetType, PetType) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_schema.py new file mode 100644 index 00000000..87739bdb --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_schema.py @@ -0,0 +1,50 @@ +from py.test import raises + +from ..registry import Registry +from ..types import SQLAlchemyObjectType +from .models import Reporter + + +def test_should_raise_if_no_model(): + with raises(Exception) as excinfo: + + class Character1(SQLAlchemyObjectType): + pass + + assert "valid SQLAlchemy Model" in str(excinfo.value) + + +def test_should_raise_if_model_is_invalid(): + with raises(Exception) as excinfo: + + class Character2(SQLAlchemyObjectType): + class Meta: + model = 1 + + assert "valid SQLAlchemy Model" in str(excinfo.value) + + +def test_should_map_fields_correctly(): + class ReporterType2(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = Registry() + + assert list(ReporterType2._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + "pets", + "articles", + "favorite_article", + ] + + +def test_should_map_only_few_fields(): + class Reporter2(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "email") + assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_sort_enums.py new file mode 100644 index 00000000..1eb106da --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_sort_enums.py @@ -0,0 +1,389 @@ +import pytest +import sqlalchemy as sa + +from graphene import Argument, Enum, List, ObjectType, Schema +from graphene.relay import Connection, Node + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from ..utils import to_type_name +from .models import Base, HairKind, Pet +from .test_query import to_std_dicts + + +def add_pets(session): + pets = [ + Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), + ] + session.add_all(pets) + session.commit() + + +def test_sort_enum(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + +def test_sort_enum_with_custom_name(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(name="CustomSortName") + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "CustomSortName" + + +def test_sort_enum_cache(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + sort_enum_2 = PetType.sort_enum() + assert sort_enum_2 is sort_enum + sort_enum_2 = PetType.sort_enum(name="PetTypeSortEnum") + assert sort_enum_2 is sort_enum + err_msg = "Sort enum for PetType has already been customized" + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(name="CustomSortName") + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_fields=["id"]) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_indexed=True) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(get_symbol_name=lambda: "foo") + + +def test_sort_enum_with_excluded_field_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["reporter_id"] + + sort_enum = PetType.sort_enum() + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + ] + + +def test_sort_enum_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(only_fields=["id", "name"]) + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + ] + + +def test_sort_argument(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_arg = PetType.sort_argument() + assert isinstance(sort_arg, Argument) + + assert isinstance(sort_arg.type, List) + sort_enum = sort_arg.type._of_type + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + assert sort_arg.default_value == ["ID_ASC"] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + + +def test_sort_argument_with_excluded_fields_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["hair_kind", "reporter_id"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + only_fields = ["id", "pet_kind"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_for_multi_column_pk(): + class MultiPkTestModel(Base): + __tablename__ = "multi_pk_test_table" + foo = sa.Column(sa.Integer, primary_key=True) + bar = sa.Column(sa.Integer, primary_key=True) + + class MultiPkTestType(SQLAlchemyObjectType): + class Meta: + model = MultiPkTestModel + + sort_arg = MultiPkTestType.sort_argument() + assert sort_arg.default_value == ["FOO_ASC", "BAR_ASC"] + + +def test_sort_argument_only_indexed(): + class IndexedTestModel(Base): + __tablename__ = "indexed_test_table" + id = sa.Column(sa.Integer, primary_key=True) + foo = sa.Column(sa.Integer, index=False) + bar = sa.Column(sa.Integer, index=True) + + class IndexedTestType(SQLAlchemyObjectType): + class Meta: + model = IndexedTestModel + + sort_arg = IndexedTestType.sort_argument(only_indexed=True) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "BAR_ASC", + "BAR_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_with_custom_symbol_names(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + def get_symbol_name(column_name, sort_asc=True): + return to_type_name(column_name) + ("Up" if sort_asc else "Down") + + sort_arg = PetType.sort_argument(get_symbol_name=get_symbol_name) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "IdUp", + "IdDown", + "NameUp", + "NameDown", + "PetKindUp", + "PetKindDown", + "HairKindUp", + "HairKindDown", + "ReporterIdUp", + "ReporterIdDown", + ] + assert sort_arg.default_value == ["IdUp"] + + +def test_sort_query(session): + add_pets(session) + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + class PetConnection(Connection): + class Meta: + node = PetNode + + class Query(ObjectType): + defaultSort = SQLAlchemyConnectionField(PetConnection) + nameSort = SQLAlchemyConnectionField(PetConnection) + multipleSort = SQLAlchemyConnectionField(PetConnection) + descSort = SQLAlchemyConnectionField(PetConnection) + singleColumnSort = SQLAlchemyConnectionField( + PetConnection, sort=Argument(PetNode.sort_enum()) + ) + noDefaultSort = SQLAlchemyConnectionField( + PetConnection, sort=PetNode.sort_argument(has_default=False) + ) + noSort = SQLAlchemyConnectionField(PetConnection, sort=None) + + query = """ + query sortTest { + defaultSort { + edges { + node { + name + } + } + } + nameSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + multipleSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + petKind + } + } + } + descSort(sort: [NAME_DESC]) { + edges { + node { + name + } + } + } + singleColumnSort(sort: NAME_DESC) { + edges { + node { + name + } + } + } + noDefaultSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + } + """ + + def makeNodes(nodeList): + nodes = [{"node": item} for item in nodeList] + return {"edges": nodes} + + expected = { + "defaultSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), + "noDefaultSort": makeNodes( + [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] + ), + "multipleSort": makeNodes( + [ + {"name": "Alf", "petKind": "CAT"}, + {"name": "Lassie", "petKind": "DOG"}, + {"name": "Barf", "petKind": "DOG"}, + ] + ), + "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), + "singleColumnSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + } # yapf: disable + + schema = Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + queryError = """ + query sortTest { + singleColumnSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(queryError, context_value={"session": session}) + assert result.errors is not None + assert '"sort" has invalid value' in result.errors[0].message + + queryNoSort = """ + query sortTest { + noDefaultSort { + edges { + node { + name + } + } + } + noSort { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute(queryNoSort, context_value={"session": session}) + assert not result.errors + # TODO: SQLite usually returns the results ordered by primary key, + # so we cannot test this way whether sorting actually happens or not. + # Also, no sort order is guaranteed by SQLite if "no order" by is used. + assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ + node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] + ] diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_types.py new file mode 100644 index 00000000..b76136fb --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_types.py @@ -0,0 +1,282 @@ +from collections import OrderedDict + +import six # noqa F401 +from promise import Promise + +from graphene import (Connection, Field, Int, Interface, Node, ObjectType, + is_node) + +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory) +from ..registry import Registry +from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions +from .models import Article, Reporter + +registry = Registry() + + +class Character(SQLAlchemyObjectType): + """Character description""" + + class Meta: + model = Reporter + registry = registry + + +class Human(SQLAlchemyObjectType): + """Human description""" + + pub_date = Int() + + class Meta: + model = Article + exclude_fields = ("id",) + registry = registry + interfaces = (Node,) + + +def test_sqlalchemy_interface(): + assert issubclass(Node, Interface) + assert issubclass(Node, Node) + + +# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) +# def test_sqlalchemy_get_node(get): +# human = Human.get_node(1, None) +# get.assert_called_with(id=1) +# assert human.id == 1 + + +def test_objecttype_registered(): + assert issubclass(Character, ObjectType) + assert Character._meta.model == Reporter + assert list(Character._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + "pets", + "articles", + "favorite_article", + ] + + +# def test_sqlalchemynode_idfield(): +# idfield = Node._meta.fields_map['id'] +# assert isinstance(idfield, GlobalIDField) + + +# def test_node_idfield(): +# idfield = Human._meta.fields_map['id'] +# assert isinstance(idfield, GlobalIDField) + + +def test_node_replacedfield(): + idfield = Human._meta.fields["pub_date"] + assert isinstance(idfield, Field) + assert idfield.type == Int + + +def test_object_type(): + class Human(SQLAlchemyObjectType): + """Human description""" + + pub_date = Int() + + class Meta: + model = Article + # exclude_fields = ('id', ) + registry = registry + interfaces = (Node,) + + assert issubclass(Human, ObjectType) + assert list(Human._meta.fields.keys()) == [ + "id", + "headline", + "pub_date", + "reporter_id", + "reporter", + ] + assert is_node(Human) + + +# Test Custom SQLAlchemyObjectType Implementation +class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + +class CustomCharacter(CustomSQLAlchemyObjectType): + """Character description""" + + class Meta: + model = Reporter + registry = registry + + +def test_custom_objecttype_registered(): + assert issubclass(CustomCharacter, ObjectType) + assert CustomCharacter._meta.model == Reporter + assert list(CustomCharacter._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + "pets", + "articles", + "favorite_article", + ] + + +# Test Custom SQLAlchemyObjectType with Custom Options +class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + custom_fields = None + + +class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__( + cls, custom_option=None, custom_fields=None, **options + ): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + _meta.fields = custom_fields + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + +class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = "custom_option" + custom_fields = OrderedDict([("custom_field", Field(Int()))]) + + +def test_objecttype_with_custom_options(): + assert issubclass(ReporterWithCustomOptions, ObjectType) + assert ReporterWithCustomOptions._meta.model == Reporter + assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ + "custom_field", + "id", + "first_name", + "last_name", + "email", + "favorite_pet_kind", + "pets", + "articles", + "favorite_article", + ] + assert ReporterWithCustomOptions._meta.custom_option == "custom_option" + assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) + + +def test_promise_connection_resolver(): + class TestConnection(Connection): + class Meta: + node = ReporterWithCustomOptions + + def resolver(_obj, _info): + return Promise.resolve([]) + + result = SQLAlchemyConnectionField.connection_resolver( + resolver, TestConnection, ReporterWithCustomOptions, None, None + ) + assert result is not None + + +# Tests for connection_field_factory + +class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): + pass + + +def test_default_connection_field_factory(): + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + + +def test_register_connection_field_factory(): + def test_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + _type = registry.get_type_for_model(model) + return _TestSQLAlchemyConnectionField(_type._meta.connection) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + connection_field_factory = test_connection_field_factory + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_registerConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_unregisterConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + unregisterConnectionFieldFactory() + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_utils.py new file mode 100644 index 00000000..e13d919c --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/tests/test_utils.py @@ -0,0 +1,101 @@ +import pytest +import sqlalchemy as sa + +from graphene import Enum, List, ObjectType, Schema, String + +from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, + to_enum_value_name, to_type_name) +from .models import Base, Editor, Pet + + +def test_get_session(): + session = "My SQLAlchemy session" + + class Query(ObjectType): + x = String() + + def resolve_x(self, info): + return get_session(info.context) + + query = """ + query ReporterQuery { + x + } + """ + + schema = Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + assert result.data["x"] == session + + +def test_to_type_name(): + assert to_type_name("make_camel_case") == "MakeCamelCase" + assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase" + assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel" + + +def test_to_enum_value_name(): + assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE" + assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME" + + +# test deprecated sort enum utility functions + + +def test_sort_enum_for_model(): + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model(Pet) + assert isinstance(enum, type(Enum)) + assert str(enum) == "PetSortEnum" + for col in sa.inspect(Pet).columns: + assert hasattr(enum, col.name + "_asc") + assert hasattr(enum, col.name + "_desc") + + +def test_sort_enum_for_model_custom_naming(): + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model( + Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D") + ) + assert str(enum) == "Foo" + for col in sa.inspect(Pet).columns: + assert hasattr(enum, col.name.upper() + "A") + assert hasattr(enum, col.name.upper() + "D") + + +def test_enum_cache(): + with pytest.warns(DeprecationWarning): + assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) + + +def test_sort_argument_for_model(): + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet) + + assert isinstance(arg.type, List) + assert arg.default_value == [Pet.id.name + "_asc"] + with pytest.warns(DeprecationWarning): + assert arg.type.of_type is sort_enum_for_model(Pet) + + +def test_sort_argument_for_model_no_default(): + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet, False) + + assert arg.default_value is None + + +def test_sort_argument_for_model_multiple_pk(): + class MultiplePK(Base): + foo = sa.Column(sa.Integer, primary_key=True) + bar = sa.Column(sa.Integer, primary_key=True) + __tablename__ = "MultiplePK" + + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(MultiplePK) + assert set(arg.default_value) == set( + (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") + ) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/graphene_sqlalchemy/types.py new file mode 100644 index 00000000..9598ccd3 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/types.py @@ -0,0 +1,446 @@ +from collections import OrderedDict +from typing import Type, Tuple, Mapping, Callable + +import sqlalchemy +from graphene import Field, InputObjectType, Dynamic, Argument, Mutation, ID +from graphene.relay import Connection, Node +from graphene.types.objecttype import ObjectType, ObjectTypeOptions +from graphene.types.structures import Structure +from graphene.types.utils import yank_fields_from_attrs +from graphene.utils.str_converters import to_snake_case +from graphql import ResolveInfo +from sqlalchemy.ext.declarative import DeclarativeMeta +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm.exc import NoResultFound + +from .converter import (convert_sqlalchemy_column, + convert_sqlalchemy_composite, + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship) +from .enums import (enum_for_field, sort_argument_for_object_type, + sort_enum_for_object_type) +from .fields import SQLAlchemyFilteredConnectionField +from .fields import default_connection_field_factory +from .interfaces import SQLAlchemyInterface +from .registry import Registry, get_global_registry +from .utils import get_query, is_mapped_class, is_mapped_instance +from .utils import pluralize_name + + +def construct_fields( + obj_type, model, registry, only_fields, exclude_fields, connection_field_factory +): + inspected_model = sqlalchemyinspect(model) + + fields = OrderedDict() + + for name, column in inspected_model.columns.items(): + is_not_in_only = only_fields and name not in only_fields + # is_already_created = name in options.fields + is_excluded = name in exclude_fields # or is_already_created + if is_not_in_only or is_excluded: + # We skip this field if we specify only_fields and is not + # in there. Or when we exclude this field in exclude_fields + continue + converted_column = convert_sqlalchemy_column(column, registry) + registry.register_orm_field(obj_type, name, column) + fields[name] = converted_column + + for name, composite in inspected_model.composites.items(): + is_not_in_only = only_fields and name not in only_fields + # is_already_created = name in options.fields + is_excluded = name in exclude_fields # or is_already_created + if is_not_in_only or is_excluded: + # We skip this field if we specify only_fields and is not + # in there. Or when we exclude this field in exclude_fields + continue + converted_composite = convert_sqlalchemy_composite(composite, registry) + registry.register_orm_field(obj_type, name, composite) + fields[name] = converted_composite + + for hybrid_item in inspected_model.all_orm_descriptors: + + if type(hybrid_item) == hybrid_property: + name = hybrid_item.__name__ + + is_not_in_only = only_fields and name not in only_fields + # is_already_created = name in options.fields + is_excluded = name in exclude_fields # or is_already_created + + if is_not_in_only or is_excluded: + # We skip this field if we specify only_fields and is not + # in there. Or when we exclude this field in exclude_fields + continue + + converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) + registry.register_orm_field(obj_type, name, hybrid_item) + fields[name] = converted_hybrid_property + + # Get all the columns for the relationships on the model + for relationship in inspected_model.relationships: + is_not_in_only = only_fields and relationship.key not in only_fields + # is_already_created = relationship.key in options.fields + is_excluded = relationship.key in exclude_fields # or is_already_created + if is_not_in_only or is_excluded: + # We skip this field if we specify only_fields and is not + # in there. Or when we exclude this field in exclude_fields + continue + converted_relationship = convert_sqlalchemy_relationship( + relationship, registry, connection_field_factory + ) + name = relationship.key + registry.register_orm_field(obj_type, name, relationship) + fields[name] = converted_relationship + + return fields + + +class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] + id = None # type: str + + +class SQLAlchemyObjectType(ObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + connection_field_factory=default_connection_field_factory, + _meta=None, + **options + ): + assert is_mapped_class(model), ( + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' + ).format(cls.__name__, model) + + if not registry: + registry = get_global_registry() + + assert isinstance(registry, Registry), ( + "The attribute registry in {} needs to be an instance of " + 'Registry, received "{}".' + ).format(cls.__name__, registry) + + sqla_fields = yank_fields_from_attrs( + construct_fields( + obj_type=cls, + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + connection_field_factory=connection_field_factory, + ), + _as=Field, + ) + + if use_connection is None and interfaces: + use_connection = any( + (issubclass(interface, Node) for interface in interfaces) + ) + + if use_connection and not connection: + # We create the connection automatically + if not connection_class: + connection_class = Connection + + connection = connection_class.create_type( + "{}Connection".format(cls.__name__), node=cls + ) + + if connection is not None: + assert issubclass(connection, Connection), ( + "The connection must be a Connection. Received {}" + ).format(connection.__name__) + + if not _meta: + _meta = SQLAlchemyObjectTypeOptions(cls) + + _meta.model = model + _meta.registry = registry + + if _meta.fields: + _meta.fields.update(sqla_fields) + else: + _meta.fields = sqla_fields + + _meta.connection = connection + _meta.id = id or "id" + + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, interfaces=interfaces, **options + ) + + if not skip_registry: + registry.register(cls) + + @classmethod + def is_type_of(cls, root, info): + if isinstance(root, cls): + return True + if not is_mapped_instance(root): + raise Exception(('Received incompatible instance "{}".').format(root)) + return isinstance(root, cls._meta.model) + + @classmethod + def get_query(cls, info): + model = cls._meta.model + return get_query(model, info.context) + + @classmethod + def get_node(cls, info, id): + try: + return cls.get_query(info).get(id) + except NoResultFound: + return None + + def resolve_id(self, info): + # graphene_type = info.parent_type.graphene_type + keys = self.__mapper__.primary_key_from_instance(self) + return tuple(keys) if len(keys) > 1 else keys[0] + + @classmethod + def enum_for_field(cls, field_name): + return enum_for_field(cls, field_name) + + sort_enum = classmethod(sort_enum_for_object_type) + + sort_argument = classmethod(sort_argument_for_object_type) + + +class SQLAlchemyInputObjectType(InputObjectType): + @classmethod + def __init_subclass_with_meta__( + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + connection_field_factory=default_connection_field_factory, + _meta=None, + **options + ): + autoexclude = [] + + # always pull ids out to a separate argument + for col in sqlalchemy.inspect(model).columns: + if ((col.primary_key and col.autoincrement) or + (isinstance(col.type, sqlalchemy.types.TIMESTAMP) and + col.server_default is not None)): + autoexclude.append(col.name) + + if not registry: + registry = get_global_registry() + sqla_fields = yank_fields_from_attrs( + construct_fields(model, registry, only_fields, exclude_fields + tuple(autoexclude), connection_field_factory), + _as=Field, + ) + # create accessor for model to be retrieved for querying + cls.sqla_model = model + if use_connection is None and interfaces: + use_connection = any( + (issubclass(interface, Node) for interface in interfaces) + ) + + if use_connection and not connection: + # We create the connection automatically + if not connection_class: + connection_class = Connection + + connection = connection_class.create_type( + "{}Connection".format(cls.__name__), node=cls + ) + + if connection is not None: + assert issubclass(connection, Connection), ( + "The connection must be a Connection. Received {}" + ).format(connection.__name__) + + for key, value in sqla_fields.items(): + if not (isinstance(value, Dynamic) or hasattr(cls, key)): + setattr(cls, key, value) + + super(SQLAlchemyInputObjectType, cls).__init_subclass_with_meta__(**options) + + +class SQLAlchemyAutoSchemaFactory(ObjectType): + + @staticmethod + def set_fields_and_attrs(klazz: Type[ObjectType], node_model: Type[SQLAlchemyInterface], field_dict: Mapping[str, Field]): + _name = to_snake_case(node_model.__name__) + field_dict[f'all_{(pluralize_name(_name))}'] = SQLAlchemyFilteredConnectionField(node_model) + field_dict[_name] = node_model.Field() + setattr(klazz, _name, node_model.Field()) + setattr(klazz, "all_{}".format(pluralize_name(_name)), SQLAlchemyFilteredConnectionField(node_model)) + + @classmethod + def __init_subclass_with_meta__( + cls, + interfaces: Tuple[Type[SQLAlchemyInterface]] = (), + models: Tuple[Type[DeclarativeMeta]] = (), + excluded_models: Tuple[Type[DeclarativeMeta]] = (), + node_interface: Type[Node] = Type[Node], + default_resolver: ResolveInfo = None, + _meta=None, + **options + ): + if not _meta: + _meta = ObjectTypeOptions(cls) + + fields = OrderedDict() + + for interface in interfaces: + if issubclass(interface, SQLAlchemyInterface): + SQLAlchemyAutoSchemaFactory.set_fields_and_attrs(cls, interface, fields) + for model in excluded_models: + if model in models: + models = models[:models.index(model)] + models[models.index(model) + 1:] + possible_types = () + for model in models: + model_name = model.__name__ + _model_name = to_snake_case(model.__name__) + + if hasattr(cls, _model_name): + continue + if hasattr(cls, "all_{}".format(pluralize_name(_model_name))): + continue + for iface in interfaces: + if issubclass(model, iface._meta.model): + model_interface = (iface,) + break + else: + model_interface = (node_interface,) + + _node_class = type(model_name, + (SQLAlchemyObjectType,), + {"Meta": {"model": model, "interfaces": model_interface, "only_fields": []}}) + fields["all_{}".format(pluralize_name(_model_name))] = SQLAlchemyFilteredConnectionField(_node_class) + setattr(cls, "all_{}".format(pluralize_name(_model_name)), SQLAlchemyFilteredConnectionField(_node_class)) + fields[_model_name] = node_interface.Field(_node_class) + setattr(cls, _model_name, node_interface.Field(_node_class)) + possible_types += (_node_class,) + if _meta.fields: + _meta.fields.update(fields) + else: + _meta.fields = fields + _meta.schema_types = possible_types + + super(SQLAlchemyAutoSchemaFactory, cls).__init_subclass_with_meta__(_meta=_meta, default_resolver=default_resolver, **options) + + @classmethod + def resolve_with_filters(cls, info: ResolveInfo, model: Type[SQLAlchemyObjectType], **kwargs): + query = model.get_query(info) + for filter_name, filter_value in kwargs.items(): + model_filter_column = getattr(model._meta.model, filter_name, None) + if not model_filter_column: + continue + if isinstance(filter_value, SQLAlchemyInputObjectType): + filter_model = filter_value.sqla_model + q = SQLAlchemyFilteredConnectionField.get_query(filter_model, info, sort=None, **kwargs) + # noinspection PyArgumentList + query = query.filter(model_filter_column == q.filter_by(**filter_value)) + else: + query = query.filter(model_filter_column == filter_value) + return query + + +class SQLAlchemyMutationOptions(ObjectTypeOptions): + model: DeclarativeMeta = None + create: bool = False + delete: bool = False + arguments: Mapping[str, Argument] = None + output: Type[ObjectType] = None + resolver: Callable = None + + +class SQLAlchemyMutation(Mutation): + @classmethod + def __init_subclass_with_meta__(cls, model=None, create=False, + delete=False, registry=None, + arguments=None, only_fields=(), + structure: Type[Structure] = None, + exclude_fields=(), **options): + meta = SQLAlchemyMutationOptions(cls) + meta.create = create + meta.model = model + meta.delete = delete + + if arguments is None and not hasattr(cls, "Arguments"): + arguments = {} + # don't include id argument on create + if not meta.create: + arguments['id'] = ID(required=True) + + # don't include input argument on delete + if not meta.delete: + inputMeta = type('Meta', (object,), { + 'model': model, + 'exclude_fields': exclude_fields, + 'only_fields': only_fields + }) + inputType = type(cls.__name__ + 'Input', + (SQLAlchemyInputObjectType,), + {'Meta': inputMeta}) + arguments = {'input': inputType(required=True)} + if not registry: + registry = get_global_registry() + output_type: ObjectType = registry.get_type_for_model(model) + if structure: + output_type = structure(output_type) + super(SQLAlchemyMutation, cls).__init_subclass_with_meta__(_meta=meta, output=output_type, arguments=arguments, **options) + + @classmethod + def mutate(cls, info, **kwargs): + session = get_session(info.context) + with session.no_autoflush: + meta = cls._meta + + if meta.create: + model = meta.model(**kwargs['input']) + session.add(model) + else: + model = session.query(meta.model).filter(meta.model.id == + kwargs['id']).first() + if meta.delete: + session.delete(model) + else: + def setModelAttributes(model, attrs): + relationships = model.__mapper__.relationships + for key, value in attrs.items(): + if key in relationships: + if getattr(model, key) is None: + # instantiate class of the same type as + # the relationship target + setattr(model, key, + relationships[key].mapper.entity()) + setModelAttributes(getattr(model, key), value) + else: + setattr(model, key, value) + + setModelAttributes(model, kwargs['input']) + session.commit() + + return model + + @classmethod + def Field(cls, *args, **kwargs): + return Field(cls._meta.output, + args=cls._meta.arguments, + resolver=cls._meta.resolver) diff --git a/graphene_sqlalchemy/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/graphene_sqlalchemy/utils.py new file mode 100644 index 00000000..3083c3f9 --- /dev/null +++ b/graphene_sqlalchemy/graphene_sqlalchemy/utils.py @@ -0,0 +1,199 @@ +import re +import warnings +from collections import OrderedDict + +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import class_mapper, object_mapper +from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError + + +def get_session(context): + return context.get("session") + + +def get_query(model, context): + query = getattr(model, "query", None) + if not query: + session = get_session(context) + if not session: + raise Exception( + "A query in the model Base or a session in the schema is required for querying.\n" + "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" + ) + query = session.query(model) + return query + + +def is_mapped_class(cls): + try: + class_mapper(cls) + except (ArgumentError, UnmappedClassError): + return False + else: + return True + + +def is_mapped_instance(cls): + try: + object_mapper(cls) + except (ArgumentError, UnmappedInstanceError): + return False + else: + return True + + +def to_type_name(name): + """Convert the given name to a GraphQL type name.""" + return "".join(part[:1].upper() + part[1:] for part in name.split("_")) + + +_re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)") +_re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])") + + +def to_enum_value_name(name): + """Convert the given name to a GraphQL enum value name.""" + return _re_enum_value_name_2.sub( + r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name) + ).upper() + + +class EnumValue(str): + """String that has an additional value attached. + + This is used to attach SQLAlchemy model columns to Enum symbols. + """ + + def __new__(cls, s, value): + return super(EnumValue, cls).__new__(cls, s) + + def __init__(self, _s, value): + super(EnumValue, self).__init__() + self.value = value + + +def _deprecated_default_symbol_name(column_name, sort_asc): + return column_name + ("_asc" if sort_asc else "_desc") + + +# unfortunately, we cannot use lru_cache because we still support Python 2 +_deprecated_object_type_cache = {} + + +def _deprecated_object_type_for_model(cls, name): + try: + return _deprecated_object_type_cache[cls, name] + except KeyError: + from .types import SQLAlchemyObjectType + + obj_type_name = name or cls.__name__ + + class ObjType(SQLAlchemyObjectType): + class Meta: + name = obj_type_name + model = cls + + _deprecated_object_type_cache[cls, name] = ObjType + return ObjType + + +def sort_enum_for_model(cls, name=None, symbol_name=None): + """Get a Graphene Enum for sorting the given model class. + + This is deprecated, please use object_type.sort_enum() instead. + """ + warnings.warn( + "sort_enum_for_model() is deprecated; use object_type.sort_enum() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from .enums import sort_enum_for_object_type + + return sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, name), + name, + get_symbol_name=symbol_name or _deprecated_default_symbol_name, + ) + + +def pluralize_name(name): + s1 = re.sub("y$", "ie", name) + return "{}s".format(s1) + + +def sort_argument_for_model(cls, has_default=True): + """Get a Graphene Argument for sorting the given model class. + + This is deprecated, please use object_type.sort_argument() instead. + """ + warnings.warn( + "sort_argument_for_model() is deprecated;" + " use object_type.sort_argument() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from graphene import Argument, List + from .enums import sort_enum_for_object_type + + enum = sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, None), + get_symbol_name=_deprecated_default_symbol_name, + ) + if not has_default: + enum.default = None + + return Argument(List(enum), default_value=enum.default) + + +argument_cache = {} +field_cache = {} + + +class FilterArgument: + pass + + +class FilterField: + pass + + +def create_filter_field(column): + from graphene import InputObjectType, Field + from .converter import convert_sqlalchemy_type + + graphene_type = convert_sqlalchemy_type(column.type, column) + if graphene_type.__class__ == Field: + return None + + name = "{}Filter".format(str(graphene_type.__class__)) + if name in field_cache: + return Field(field_cache[name]) + + fields = OrderedDict((key, Field(graphene_type.__class__)) + for key in ["equal", "notEqual", "lessThan", "greaterThan", "like"]) + field_class: InputObjectType = type(name, (FilterField, InputObjectType), {}) + field_class._meta.fields.update(fields) + + field_cache[name] = field_class + return Field(field_class) + + +def create_filter_argument(cls): + from graphene import Argument, InputObjectType + from sqlalchemy import inspect + name = "{}Filter".format(cls.__name__) + if name in argument_cache: + return Argument(argument_cache[name]) + import re + + NAME_PATTERN = r"^[_a-zA-Z][_a-zA-Z0-9]*$" + COMPILED_NAME_PATTERN = re.compile(NAME_PATTERN) + fields = OrderedDict((column.name, field) + for column, field in [(column, create_filter_field(column)) + for column in inspect(cls).columns.values()] if field and COMPILED_NAME_PATTERN.match(column.name)) + argument_class: InputObjectType = type(name, (FilterArgument, InputObjectType), {}) + argument_class._meta.fields.update(fields) + argument_cache[name] = argument_class + return Argument(argument_class)