Feature/synced collection/optimize jsondict validation (#508)
Define a single validator for JSONAttrDict classes that combines the logic of the existing validators while reducing collection traversal costs. Also, rather than converting numpy arrays before type resolution, simply bypass the resolver's cache for them.

* Use single separate validator for state points for performance.

* Remove preprocessor from type resolver and instead use a blocklist that prevents caching data for certain types.

* Reorder resolvers to optimize performance.

* Make sure not to include strings as sequences.

* Move state point validator to collection_json and use for all JSONAttrDict types.

* Make sure to also check complex types.

* Add back missing period lost during stash merge.

* Address review comments.
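The caching subtlety behind the blocklist change is easy to see in isolation. A minimal illustration (plain numpy, nothing signac-specific):

import numpy

# A 0d array and a 1d array share one concrete type, numpy.ndarray, yet
# must be classified differently (scalar-like vs. sequence-like). A cache
# keyed on type(obj) would return a stale answer for one of them, so
# numpy.ndarray is never cached and is re-resolved on every call instead.
scalar_like = numpy.array(1.0)      # ndim == 0
sequence_like = numpy.array([1.0])  # ndim == 1
assert type(scalar_like) is type(sequence_like)
assert scalar_like.ndim != sequence_like.ndim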
vyasr committed Feb 19, 2021
1 parent 48863c4 commit f0537c4
Showing 8 changed files with 184 additions and 45 deletions.
4 changes: 3 additions & 1 deletion signac/contrib/job.py
@@ -14,12 +14,13 @@
from deprecation import deprecated

from ..core.h5store import H5StoreManager
from ..errors import KeyTypeError
from ..sync import sync_jobs
from ..synced_collections.backends.collection_json import (
BufferedJSONAttrDict,
JSONAttrDict,
json_attr_dict_validator,
)
from ..synced_collections.errors import KeyTypeError
from ..version import __version__
from .errors import DestinationExistsError, JobsCorruptedError
from .hashing import calc_id
@@ -46,6 +47,7 @@ class _StatePointDict(JSONAttrDict):
"""

_PROTECTED_KEYS: Tuple[str, ...] = JSONAttrDict._PROTECTED_KEYS + ("_jobs",)
_all_validators = (json_attr_dict_validator,)

def __init__(
self,
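`_StatePointDict` opts out of validator inheritance by assigning `_all_validators` directly. The resolution pattern is roughly the following hypothetical sketch (signac's actual logic lives in `SyncedCollection` and may differ in detail):

class Base:
    _validators = ()

    @classmethod
    def _resolve_validators(cls):
        # An explicit _all_validators wins; otherwise merge the _validators
        # contributed by every class in the MRO, preserving order.
        if "_all_validators" in cls.__dict__:
            return cls._all_validators
        merged = []
        for klass in cls.__mro__:
            for validator in getattr(klass, "_validators", ()):
                if validator not in merged:
                    merged.append(validator)
        return tuple(merged)

Under that assumption, defining `_all_validators = (json_attr_dict_validator,)` replaces the chain assembled from `_convert_key_to_str`, `json_format_validator`, and `no_dot_in_key` with a single function and a single traversal per validation.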
108 changes: 104 additions & 4 deletions signac/synced_collections/backends/collection_json.py
@@ -8,16 +8,25 @@
import os
import uuid
import warnings
from typing import Callable, Sequence, Tuple
from collections.abc import Mapping, Sequence
from typing import Callable
from typing import Sequence as Sequence_t
from typing import Tuple

from .. import SyncedCollection, SyncedDict, SyncedList
from ..buffers.memory_buffered_collection import SharedMemoryFileBufferedCollection
from ..buffers.serialized_file_buffered_collection import (
SerializedFileBufferedCollection,
)
from ..data_types.attr_dict import AttrDict
from ..errors import KeyTypeError
from ..utils import SyncedCollectionJSONEncoder
from ..errors import InvalidKeyError, KeyTypeError
from ..numpy_utils import (
_is_atleast_1d_numpy_array,
_is_complex,
_is_numpy_scalar,
_numpy_cache_blocklist,
)
from ..utils import AbstractTypeResolver, SyncedCollectionJSONEncoder
from ..validators import json_format_validator, no_dot_in_key

"""
Expand Down Expand Up @@ -77,6 +86,88 @@ def _convert_key_to_str(data):
_convert_key_to_str(value)


_json_attr_dict_validator_type_resolver = AbstractTypeResolver(
{
# We identify >0d numpy arrays as sequences for validation purposes.
"SEQUENCE": lambda obj: (isinstance(obj, Sequence) and not isinstance(obj, str))
or _is_atleast_1d_numpy_array(obj),
"NUMPY": lambda obj: _is_numpy_scalar(obj),
"BASE": lambda obj: isinstance(obj, (str, int, float, bool, type(None))),
"MAPPING": lambda obj: isinstance(obj, Mapping),
},
cache_blocklist=_numpy_cache_blocklist,
)


def json_attr_dict_validator(data):
"""Validate data for JSONAttrDict.
This validator combines the logic from the following validators into one to
make validation more efficient:
This validator combines the following logic:
- JSON format validation
- Ensuring no dots are present in string keys
- Converting non-str keys to strings. This is a backwards compatibility
layer that will be removed in signac 2.0.
Parameters
----------
data
Data to validate.
Raises
------
KeyTypeError
If key data type is not supported.
TypeError
If the data type of ``data`` is not supported.
"""
switch_type = _json_attr_dict_validator_type_resolver.get_type(data)

if switch_type == "BASE":
return
elif switch_type == "MAPPING":
# Explicitly call `list(data)` to get a fixed list of keys to avoid
# running into issues with iterating over a DictKeys view while
# modifying the dict at the same time. Inside the loop, we:
# 1) validate the key, converting to string if necessary
# 2) pop and validate the value
# 3) reassign the value to the (possibly converted) key
for key in list(data):
json_attr_dict_validator(data[key])
if isinstance(key, str):
if "." in key:
raise InvalidKeyError(
f"Mapping keys may not contain dots ('.'): {key}."
)
elif isinstance(key, (int, bool, type(None))):
# TODO: Remove this branch in signac 2.0.
warnings.warn(
f"Use of {type(key).__name__} as key is deprecated "
"and will be removed in version 2.0.",
DeprecationWarning,
)
data[str(key)] = data.pop(key)
else:
raise KeyTypeError(
f"Mapping keys must be str, int, bool or None, not {type(key).__name__}."
)
elif switch_type == "SEQUENCE":
for value in data:
json_attr_dict_validator(value)
elif switch_type == "NUMPY":
if _is_numpy_scalar(data.item()):
raise TypeError("NumPy extended precision types are not JSON serializable.")
elif _is_complex(data):
raise TypeError("Complex numbers are not JSON serializable.")
else:
raise TypeError(
f"Object of type {type(data).__name__} is not JSON serializable."
)
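A usage sketch of the combined validator (the import paths match those used in `job.py` above; the `DeprecationWarning` fires on the legacy int key):

import warnings

from signac.synced_collections.backends.collection_json import (
    json_attr_dict_validator,
)
from signac.synced_collections.errors import InvalidKeyError

data = {"a": 1, 2: [3.0, None], "nested": {"ok": True}}
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    json_attr_dict_validator(data)    # one traversal validates everything
assert "2" in data and 2 not in data  # legacy int key converted in place

try:
    json_attr_dict_validator({"a.b": 1})
except InvalidKeyError as error:
    print(error)  # Mapping keys may not contain dots ('.'): a.b.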


"""
Here we define the main JSONCollection class that encapsulates most of the
logic for reading from and writing to JSON files. The remaining classes in
@@ -127,7 +218,7 @@ class JSONCollection(SyncedCollection):
# in the future, however, the _convert_key_to_str validator will be removed in
# signac 2.0 so this is OK (that validator is modifying the data in place,
# which is unsupported behavior that will be removed in signac 2.0 as well).
_validators: Sequence[Callable] = (_convert_key_to_str, json_format_validator)
_validators: Sequence_t[Callable] = (_convert_key_to_str, json_format_validator)

def __init__(self, filename=None, write_concern=False, *args, **kwargs):
# The `_filename` attribute _must_ be defined prior to calling the
@@ -533,7 +624,10 @@ class JSONAttrDict(JSONDict, AttrDict):
"""

_backend = __name__ + ".attr" # type: ignore
# Define the validators in case subclasses want to inherit the correct
# behavior, but define _all_validators for performance of this class.
_validators = (no_dot_in_key,)
_all_validators = (json_attr_dict_validator,)


class JSONAttrList(JSONList):
@@ -546,7 +640,10 @@ class BufferedJSONAttrDict(BufferedJSONDict, AttrDict):
"""A buffered :class:`JSONAttrDict`."""

_backend = __name__ + ".buffered_attr" # type: ignore
# Define the validators in case subclasses want to inherit the correct
# behavior, but define _all_validators for performance of this class.
_validators = (no_dot_in_key,)
_all_validators = (json_attr_dict_validator,)


class BufferedJSONAttrList(BufferedJSONList):
@@ -559,7 +656,10 @@ class MemoryBufferedJSONAttrDict(MemoryBufferedJSONDict, AttrDict):
"""A buffered :class:`JSONAttrDict`."""

_backend = __name__ + ".memory_buffered_attr" # type: ignore
# Define the validators in case subclasses want to inherit the correct
# behavior, but define _all_validators for performance of this class.
_validators = (no_dot_in_key,)
_all_validators = (json_attr_dict_validator,)


class MemoryBufferedJSONAttrList(MemoryBufferedJSONList):
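All of the ``*AttrDict`` variants above now validate through this single function. A usage sketch (the ``filename`` keyword follows the ``JSONCollection`` constructor shown earlier; attribute-style access comes from ``AttrDict``):

from signac.synced_collections.backends.collection_json import JSONAttrDict
from signac.synced_collections.errors import InvalidKeyError

doc = JSONAttrDict(filename="data.json")
doc.a = 1            # attribute-style access, synced to data.json
doc["b"] = {"c": 2}  # nested data is validated recursively

try:
    doc["x.y"] = 3   # dots are reserved as nested-key delimiters
except InvalidKeyError as error:
    print(error)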
8 changes: 6 additions & 2 deletions signac/synced_collections/data_types/synced_list.py
@@ -10,7 +10,11 @@

from collections.abc import MutableSequence, Sequence

from ..numpy_utils import _convert_numpy, _is_atleast_1d_numpy_array
from ..numpy_utils import (
_convert_numpy,
_is_atleast_1d_numpy_array,
_numpy_cache_blocklist,
)
from ..utils import AbstractTypeResolver
from .synced_collection import SyncedCollection, _sc_resolver

@@ -22,7 +26,7 @@
or _is_atleast_1d_numpy_array(obj)
),
},
_convert_numpy,
cache_blocklist=_numpy_cache_blocklist,
)


20 changes: 19 additions & 1 deletion signac/synced_collections/numpy_utils.py
@@ -9,8 +9,10 @@
import numpy

NUMPY = True
_numpy_cache_blocklist = (numpy.ndarray,)
except ImportError:
NUMPY = False
_numpy_cache_blocklist = None # type: ignore


class NumpyConversionWarning(UserWarning):
@@ -70,4 +72,20 @@ def _is_numpy_scalar(data):
bool
Whether or not the input is a numpy scalar type.
"""
return NUMPY and (isinstance(data, (numpy.number, numpy.bool_)))
return NUMPY and (
(isinstance(data, (numpy.number, numpy.bool_)))
or (isinstance(data, numpy.ndarray) and data.ndim == 0)
)


def _is_complex(data):
"""Check if an object is complex.
This function works for both numpy raw Python data types.
Returns
-------
bool
Whether or not the input is a complex number.
"""
return (NUMPY and numpy.iscomplex(data).any()) or (isinstance(data, complex))
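The widened helpers can be exercised directly (a sketch; these are internal underscore-prefixed functions, imported here purely for illustration):

import numpy

from signac.synced_collections.numpy_utils import _is_complex, _is_numpy_scalar

assert _is_numpy_scalar(numpy.float64(1.5))      # numpy scalar type
assert _is_numpy_scalar(numpy.array(1.5))        # 0d arrays now count too
assert not _is_numpy_scalar(numpy.array([1.5]))  # 1d arrays do not

assert _is_complex(1 + 2j)                       # raw Python complex
assert _is_complex(numpy.complex128(1 + 2j))     # numpy complex scalar
assert not _is_complex(numpy.float64(1.5))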
40 changes: 24 additions & 16 deletions signac/synced_collections/utils.py
@@ -24,15 +24,31 @@ class AbstractTypeResolver:
of types that must be resolved and a way to identify each of these (which
may be expensive), it maintains a local cache of all instances of a given
type that have previously been observed. This reduces the cost of type checking
to a simple dict lookup, except for the first time a new type is observed.
to a simple ``dict`` lookup, except for the first time a new type is observed.
Parameters
----------
abstract_type_identifiers : collections.abc.Mapping
abstract_type_identifiers : Mapping
A mapping from a string identifier for a group of types (e.g. ``"MAPPING"``)
to a callable that can be used to identify that type. Due to insertion order
guarantees of dictionaries in Python>=3.6 (officially 3.7), it is beneficial
guarantees of dictionaries in Python>=3.6 (officially 3.7), it may be beneficial
to order this dictionary with the most frequently occurring types first.
However, unless users have many different concrete types implementing
the same abstract interface (e.g. many Mapping types identified via
``isinstance(obj, Mapping)``), any performance gain should be negligible
since the callables will only be executed once per type.
cache_blocklist : Sequence, optional
A sequence of types that should not be cached. If objects of the same
concrete type would be classified into different groups by the callables
in ``abstract_type_identifiers``, this argument lets users mark that
type as uncacheable. It should be used sparingly
because performance will quickly degrade if many calls to
:meth:`get_type` involve types that cannot be cached. The identifiers
(keys in ``abstract_type_identifiers``) corresponding to elements of the
blocklist should be placed first in the ``abstract_type_identifiers``
dictionary since they will never be cached and are therefore the most
likely callables to be used repeatedly (Default value = None).
Attributes
----------
@@ -43,19 +59,13 @@ class AbstractTypeResolver:
type_map : Dict[Type, str]
A mapping from concrete types to the corresponding named abstract type
from :attr:`~.abstract_type_identifiers`.
preprocessor : callable or None
An operation to perform on an object before type lookup. Providing this
callable can be used if input data types cannot be checked because
objects of a given type must be treated differently based on additional
criteria, in which case this function can be used to preprocess them and
convert them to a suitable type for type-checking.
"""

def __init__(self, abstract_type_identifiers, preprocessor=None):
def __init__(self, abstract_type_identifiers, cache_blocklist=None):
self.abstract_type_identifiers = abstract_type_identifiers
self.type_map = {}
self.preprocessor = preprocessor
self.cache_blocklist = cache_blocklist if cache_blocklist is not None else ()

def get_type(self, obj):
"""Get the type string corresponding to this data type.
@@ -73,19 +83,17 @@ def get_type(self, obj):
will return ``None``.
"""
if self.preprocessor is not None:
obj = self.preprocessor(obj)

obj_type = type(obj)
enum_type = None
try:
enum_type = self.type_map[obj_type]
except KeyError:
for data_type, id_func in self.abstract_type_identifiers.items():
if id_func(obj):
enum_type = self.type_map[obj_type] = data_type
enum_type = data_type
break
self.type_map[obj_type] = enum_type
if obj_type not in self.cache_blocklist:
self.type_map[obj_type] = enum_type

return enum_type
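The blocklist behavior in isolation (hypothetical type groups, not the ones signac registers; only the constructor and ``get_type`` shown above are assumed):

import numpy

from signac.synced_collections.utils import AbstractTypeResolver

resolver = AbstractTypeResolver(
    {
        # Per the docstring, identifiers matching blocklisted types go first.
        "SCALAR": lambda obj: numpy.ndim(obj) == 0,
        "ARRAY": lambda obj: isinstance(obj, numpy.ndarray),
    },
    cache_blocklist=(numpy.ndarray,),
)

# One concrete type, two answers: the result is recomputed on each call
# for ndarray instead of being cached by type.
assert resolver.get_type(numpy.array(1.0)) == "SCALAR"
assert resolver.get_type(numpy.array([1.0])) == "ARRAY"
assert numpy.ndarray not in resolver.type_map
assert resolver.get_type(3) == "SCALAR"  # non-blocklisted types still cache
assert int in resolver.type_map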

22 changes: 14 additions & 8 deletions signac/synced_collections/validators.py
@@ -13,14 +13,18 @@
from collections.abc import Mapping, Sequence

from .errors import InvalidKeyError, KeyTypeError
from .numpy_utils import _convert_numpy, _is_atleast_1d_numpy_array, _is_numpy_scalar
from .numpy_utils import (
_is_atleast_1d_numpy_array,
_is_complex,
_is_numpy_scalar,
_numpy_cache_blocklist,
)
from .utils import AbstractTypeResolver

_no_dot_in_key_type_resolver = AbstractTypeResolver(
{
"MAPPING": lambda obj: isinstance(obj, Mapping),
"NON_STR_SEQUENCE": lambda obj: isinstance(obj, Sequence)
and not isinstance(obj, str),
"SEQUENCE": lambda obj: isinstance(obj, Sequence) and not isinstance(obj, str),
}
)

@@ -58,7 +62,7 @@ def no_dot_in_key(data):
f"Mapping keys must be str, int, bool or None, not {type(key).__name__}"
)
no_dot_in_key(value)
elif switch_type == "NON_STR_SEQUENCE":
elif switch_type == "SEQUENCE":
for value in data:
no_dot_in_key(value)

@@ -96,14 +100,14 @@ def require_string_key(data):

_json_format_validator_type_resolver = AbstractTypeResolver(
{
"BASE": lambda obj: isinstance(obj, (str, int, float, bool, type(None))),
"MAPPING": lambda obj: isinstance(obj, Mapping),
# We identify >0d numpy arrays as sequences for validation purposes.
"SEQUENCE": lambda obj: isinstance(obj, Sequence)
"SEQUENCE": lambda obj: (isinstance(obj, Sequence) and not isinstance(obj, str))
or _is_atleast_1d_numpy_array(obj),
"NUMPY": lambda obj: _is_numpy_scalar(obj),
"BASE": lambda obj: isinstance(obj, (str, int, float, bool, type(None))),
"MAPPING": lambda obj: isinstance(obj, Mapping),
},
preprocessor=_convert_numpy,
cache_blocklist=_numpy_cache_blocklist,
)


@@ -138,6 +142,8 @@ def json_format_validator(data):
elif switch_type == "NUMPY":
if _is_numpy_scalar(data.item()):
raise TypeError("NumPy extended precision types are not JSON serializable.")
elif _is_complex(data):
raise TypeError("Complex numbers are not JSON serializable.")
else:
raise TypeError(
f"Object of type {type(data).__name__} is not JSON serializable"
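A quick check of the added complex-number guard in ``json_format_validator`` (a sketch; requires numpy for the ``NUMPY`` branch):

import numpy

from signac.synced_collections.validators import json_format_validator

json_format_validator({"a": [1, 2.5, "x", None]})  # JSON-safe data passes
json_format_validator(numpy.float64(1.0))          # real numpy scalars pass

try:
    json_format_validator(numpy.complex128(1 + 2j))
except TypeError as error:
    print(error)  # Complex numbers are not JSON serializable.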
7 changes: 7 additions & 0 deletions tests/test_synced_collections/test_utils.py
@@ -60,6 +60,13 @@ def test_json_encoder(tmpdir):
assert json.dumps(synced_data, cls=SyncedCollectionJSONEncoder) == json_str_data

if NUMPY:
# Test both scalar and array numpy types since they could have
# different problems.
array = numpy.array(3)
with pytest.warns(NumpyConversionWarning):
synced_data["foo"] = array
assert isinstance(synced_data["foo"], int)

array = numpy.random.rand(3)
with pytest.warns(NumpyConversionWarning):
synced_data["foo"] = array