Skip to content

Commit

Permalink
Merge a9482db into 3e0891a
Browse files Browse the repository at this point in the history
  • Loading branch information
smith-m committed Aug 23, 2018
2 parents 3e0891a + a9482db commit a5bebab
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 36 deletions.
28 changes: 21 additions & 7 deletions marshmallow_jsonschema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@

from .validation import handle_length, handle_one_of, handle_range


__all__ = (
'JSONSchema',
)


TYPE_MAP = {
dict: {
'type': 'object',
Expand Down Expand Up @@ -72,7 +70,6 @@
},
}


FIELD_VALIDATORS = {
validate.Length: handle_length,
validate.OneOf: handle_one_of,
Expand All @@ -91,6 +88,7 @@ def __init__(self, *args, **kwargs):
"""Setup internal cache of nested fields, to prevent recursion."""
self._nested_schema_classes = {}
self.nested = kwargs.pop('nested', False)
self.prefer_data_key = kwargs.pop('prefer_data_key', False)
super(JSONSchema, self).__init__(*args, **kwargs)

def _get_default_mapping(self, obj):
Expand All @@ -112,7 +110,7 @@ def get_properties(self, obj):

for field_name, field in sorted(obj.fields.items()):
schema = self._get_schema_for_field(obj, field)
properties[field.name] = schema
properties[self._get_property_name_for_field(field)] = schema

return properties

Expand All @@ -122,16 +120,18 @@ def get_required(self, obj):

for field_name, field in sorted(obj.fields.items()):
if field.required:
required.append(field.name)
required.append(
self._get_property_name_for_field(field)
)

return required or missing

def _from_python_type(self, obj, field, pytype):
"""Get schema definition from python type."""
json_schema = {
'title': field.attribute or field.name,
'title': field.attribute or self._get_property_name_for_field(
field),
}

for key, val in TYPE_MAP[pytype].items():
json_schema[key] = val

Expand Down Expand Up @@ -182,6 +182,20 @@ def _get_schema_for_field(self, obj, field):
)
return schema

def _get_property_name_for_field(self, field):
"""Get property name for field based on serialized object"""
name = field.name

if self.prefer_data_key:
# Handle change in load_from / dump_to between Marshmallow
# versions 2 and 3.
if marshmallow.__version__.split('.', 1)[0] >= '3':
name = field.data_key or name
else:
name = field.load_from or field.dump_to or name

return name

def _from_nested_schema(self, obj, field):
"""Support nested field."""
if isinstance(field.nested, basestring):
Expand Down
Loading

0 comments on commit a5bebab

Please sign in to comment.