Skip to content

Commit

Permalink
Remove logic related to keygetter.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmcarp committed Jun 21, 2015
1 parent ec04241 commit 7f36ba3
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 43 deletions.
1 change: 0 additions & 1 deletion marshmallow_sqlalchemy/__init__.py
Expand Up @@ -9,7 +9,6 @@
from .convert import (
ModelConverter,
fields_for_model,
get_pk_from_identity,
property2field,
column2field,
field_for,
Expand Down
31 changes: 7 additions & 24 deletions marshmallow_sqlalchemy/convert.py
Expand Up @@ -4,25 +4,12 @@

import marshmallow as ma
from marshmallow import validate, fields
from marshmallow.compat import text_type
from sqlalchemy.dialects import postgresql, mysql
from sqlalchemy.orm.util import identity_key
import sqlalchemy as sa

from .exceptions import ModelConversionError
from .fields import Related

def get_pk_from_identity(obj):
"""Get primary key for `obj`. If `obj` has a compound primary key,
return a string of keys separated by ``":"``. This is the default keygetter for
used by `ModelSchema <marshmallow_sqlalchemy.ModelSchema>`.
"""
_, key = identity_key(instance=obj)
if len(key) == 1:
return key[0]
else: # Compund primary key
return ':'.join(text_type(x) for x in key)

def _is_field(value):
return (
isinstance(value, type) and
Expand Down Expand Up @@ -64,26 +51,24 @@ class ModelConverter(object):
'ONETOMANY': True,
}

def fields_for_model(self, model, session=None, include_fk=False, keygetter=None, fields=None):
def fields_for_model(self, model, session=None, include_fk=False, fields=None):
result = {}
for prop in model.__mapper__.iterate_properties:
if fields and prop.key not in fields:
continue
if hasattr(prop, 'columns'):
if not include_fk and prop.columns[0].foreign_keys:
continue
field = self.property2field(prop, session=session, keygetter=keygetter)
field = self.property2field(prop, session=session)
if field:
result[prop.key] = field
return result

def property2field(self, prop, session=None, keygetter=None, instance=True, **kwargs):
def property2field(self, prop, session=None, instance=True, **kwargs):
field_class = self._get_field_class_for_property(prop)
if not instance:
return field_class
field_kwargs = self._get_field_kwargs_for_property(
prop, session=session, keygetter=keygetter
)
field_kwargs = self._get_field_kwargs_for_property(prop, session=session)
field_kwargs.update(kwargs)
ret = field_class(**field_kwargs)
if hasattr(prop, 'direction') and self.DIRECTION_MAPPING[prop.direction.name]:
Expand Down Expand Up @@ -133,13 +118,13 @@ def _get_field_class_for_property(self, prop):
field_cls = self._get_field_class_for_column(column)
return field_cls

def _get_field_kwargs_for_property(self, prop, session=None, keygetter=None):
def _get_field_kwargs_for_property(self, prop, session=None):
kwargs = self.get_base_kwargs()
if hasattr(prop, 'columns'):
column = prop.columns[0]
self._add_column_kwargs(kwargs, column)
if hasattr(prop, 'direction'): # Relationship property
self._add_relationship_kwargs(kwargs, prop, session=session, keygetter=keygetter)
self._add_relationship_kwargs(kwargs, prop, session=session)
if getattr(prop, 'doc', None): # Useful for documentation generation
kwargs['description'] = prop.doc
return kwargs
Expand All @@ -165,7 +150,7 @@ def _add_column_kwargs(self, kwargs, column):
if getattr(column, 'primary_key', False):
kwargs['dump_only'] = True

def _add_relationship_kwargs(self, kwargs, prop, session, keygetter=None):
def _add_relationship_kwargs(self, kwargs, prop, session):
"""Add keyword arguments to kwargs (in-place) based on the passed in
relationship `Property`.
"""
Expand Down Expand Up @@ -201,8 +186,6 @@ def get_base_kwargs(self):
:param Property prop: SQLAlchemy Property.
:param Session session: SQLALchemy session.
:param keygetter: See `marshmallow.fields.QuerySelect` for documenation on the
keygetter parameter.
:param bool instance: If `True`, return `Field` instance, computing relevant kwargs
from the given property. If `False`, return the `Field` class.
:param kwargs: Additional keyword arguments to pass to the field constructor.
Expand Down
16 changes: 2 additions & 14 deletions marshmallow_sqlalchemy/schema.py
Expand Up @@ -2,9 +2,9 @@
import inspect

import marshmallow as ma
from marshmallow.compat import with_metaclass, PY2
from marshmallow.compat import with_metaclass

from .convert import get_pk_from_identity, ModelConverter
from .convert import ModelConverter


class SchemaOpts(ma.SchemaOpts):
Expand All @@ -13,10 +13,6 @@ class SchemaOpts(ma.SchemaOpts):
- ``model``: The SQLAlchemy model to generate the `Schema` from (required).
- ``sqla_session``: SQLAlchemy session (required).
- ``keygetter``: A `str` or function. Can be a callable or a string.
In the former case, it must be a one-argument callable which returns a unique comparable
key. In the latter case, the string specifies the name of
an attribute of the ORM-mapped object.
- ``model_converter``: `ModelConverter` class to use for converting the SQLAlchemy model to
marshmallow fields.
"""
Expand All @@ -27,7 +23,6 @@ def __init__(self, meta):
self.sqla_session = getattr(meta, 'sqla_session', None)
if self.model and not self.sqla_session:
raise ValueError('SQLAlchemyModelSchema requires the "sqla_session" class Meta option')
self.keygetter = getattr(meta, 'keygetter', get_pk_from_identity)
self.model_converter = getattr(meta, 'model_converter', ModelConverter)

class SchemaMeta(ma.schema.SchemaMeta):
Expand All @@ -43,19 +38,12 @@ def get_declared_fields(mcs, klass, *args, **kwargs):
# inheriting from base classes
for base in inspect.getmro(klass):
opts = klass.opts
# In Python 2, Meta.keygetter will be an unbound method,
# so we need to get the unbound function
if PY2 and inspect.ismethod(opts.keygetter):
keygetter = opts.keygetter.im_func
else:
keygetter = opts.keygetter
if opts.model:
Converter = opts.model_converter
converter = Converter()
declared_fields = converter.fields_for_model(
opts.model,
opts.sqla_session,
keygetter=keygetter,
fields=opts.fields,
)
break
Expand Down
4 changes: 0 additions & 4 deletions tests/test_marshmallow_sqlalchemy.py
Expand Up @@ -104,9 +104,6 @@ def __init__(self):
self.Student = Student
return _models()

def hyperlink_keygetter(obj):
return obj.url

@pytest.fixture()
def schemas(models, session):
class CourseSchema(ModelSchema):
Expand All @@ -128,7 +125,6 @@ class HyperlinkStudentSchema(ModelSchema):
class Meta:
model = models.Student
sqla_session = session
keygetter = hyperlink_keygetter

# Again, so we can use dot-notation
class _schemas(object):
Expand Down

0 comments on commit 7f36ba3

Please sign in to comment.