Skip to content

Commit

Permalink
Fix/issue 168 (#205)
Browse files Browse the repository at this point in the history
* Implement unit tests catching issue #168.

* Refactor state point management in Job class.

And explicitly check key type in SyncedDict.

* Remove obsolete SPDict class from job module.

* Update changelog.

* Refactor the `Job.reset_statepoint()` function for clarity.

* Refactor 'reset_document' function into JSONDict class.

The reset function for the Job.document and Project.document is thereby
implemented as part of the JSONDict class.

* Convert mappings upon reset (JSONDict).

* Refactor Job constructor.
  • Loading branch information
csadorf committed Jul 14, 2019
1 parent 8556a04 commit 34ac45c
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 73 deletions.
3 changes: 3 additions & 0 deletions changelog.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ next
- Reduce the logging verbosity about a missing default host key in the configuration (#201).
- Add ``read_json()`` and ``to_json()`` methods to Collection class (#104).
- Fix issue with incorrect detection of dict-like files managed with the ``DictManager`` class (e.g. ``job.stores``) (#203).
- Fix issue causing a failure of the automatic conversion of valid key types (#168, #205).
- Improve the 'dots in keys' error message to make it easier to fix related issues (#170, #205).


[1.1.0] -- 2019-05-19
---------------------
Expand Down
50 changes: 8 additions & 42 deletions signac/contrib/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import errno
import logging
import shutil
import uuid

from ..common import six
from ..core import json
Expand All @@ -16,10 +15,7 @@
from .utility import _mkdir_p
from .errors import DestinationExistsError, JobsCorruptedError
from ..sync import sync_jobs
if six.PY2:
from collections import Mapping
else:
from collections.abc import Mapping


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,16 +54,9 @@ class Job(object):
def __init__(self, project, statepoint, _id=None):
self._project = project

# Ensure that the job id is configured
if _id is None:
self._statepoint = json.loads(json.dumps(statepoint))
self._id = calc_id(self._statepoint)
else:
self._statepoint = dict(statepoint)
self._id = _id

# Prepare job statepoint
self._sp = SyncedAttrDict(self._statepoint, parent=_sp_save_hook(self))
# Set statepoint and id
self._statepoint = SyncedAttrDict(statepoint, parent=_sp_save_hook(self))
self._id = calc_id(self._statepoint()) if _id is None else _id

# Prepare job working directory
self._wd = os.path.join(project.workspace(), self._id)
Expand Down Expand Up @@ -155,9 +144,8 @@ def reset_statepoint(self, new_statepoint):
else:
raise
# Update this instance
self._statepoint = dst._statepoint
self._statepoint = SyncedAttrDict(dst._statepoint._as_dict(), parent=_sp_save_hook(self))
self._id = dst._id
self._sp = SyncedAttrDict(self._statepoint, parent=_sp_save_hook(self))
self._wd = dst._wd
self._fn_doc = dst._fn_doc
self._document = None
Expand Down Expand Up @@ -214,43 +202,21 @@ def statepoint(self):
`sp_dict = job.statepoint()` instead of `sp = job.statepoint`.
For more information, see :class:`~signac.JSONDict`.
"""
if self._sp is None:
self._sp = SyncedAttrDict(self._statepoint, parent=_sp_save_hook(self))
return self._sp
return self._statepoint

@statepoint.setter
def statepoint(self, new_sp):
self._reset_sp(new_sp)

@property
def sp(self):
""" Alias for :attr:`Job.statepoint`.
.. warning::
As with :attr:`Job.statepoint`, use `job.sp()` instead of
`job.sp` if you need a deep copy that will not modify the
underlying persistent JSON file.
"""
"Alias for :attr:`Job.statepoint`."
return self.statepoint

@sp.setter
def sp(self, new_sp):
self.statepoint = new_sp

def _reset_document(self, new_doc):
if not isinstance(new_doc, Mapping):
raise ValueError("The document must be a mapping.")
dirname, filename = os.path.split(self._fn_doc)
fn_tmp = os.path.join(dirname, '._{uid}_{fn}'.format(
uid=uuid.uuid4(), fn=filename))
with open(fn_tmp, 'wb') as tmpfile:
tmpfile.write(json.dumps(new_doc).encode())
if six.PY2:
os.rename(fn_tmp, self._fn_doc)
else:
os.replace(fn_tmp, self._fn_doc)

@property
def document(self):
"""The document associated with this job.
Expand All @@ -274,7 +240,7 @@ def document(self):

@document.setter
def document(self, new_doc):
self._reset_document(new_doc)
self.document.reset(new_doc)

@property
def doc(self):
Expand Down
20 changes: 5 additions & 15 deletions signac/contrib/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
from .errors import DestinationExistsError
from .errors import JobsCorruptedError
if six.PY2:
from collections import Mapping, Iterable
from collections import Iterable
else:
from collections.abc import Mapping, Iterable
from collections.abc import Iterable

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -259,17 +259,7 @@ def isfile(self, filename):
return os.path.isfile(self.fn(filename))

def _reset_document(self, new_doc):
if not isinstance(new_doc, Mapping):
raise ValueError("The document must be a mapping.")
dirname, filename = os.path.split(self._fn_doc)
fn_tmp = os.path.join(dirname, '._{uid}_{fn}'.format(
uid=uuid.uuid4(), fn=filename))
with open(fn_tmp, 'wb') as tmpfile:
tmpfile.write(json.dumps(new_doc).encode())
if six.PY2:
os.rename(fn_tmp, self._fn_doc)
else:
os.replace(fn_tmp, self._fn_doc)
self.document.reset(new_doc)

@property
def document(self):
Expand Down Expand Up @@ -372,7 +362,7 @@ def open_job(self, statepoint=None, id=None):
# second best case
job = self.Job(project=self, statepoint=statepoint)
if job._id not in self._sp_cache:
self._sp_cache[job._id] = dict(job._statepoint)
self._sp_cache[job._id] = job.statepoint._as_dict()
return job
elif id in self._sp_cache:
# optimal case
Expand Down Expand Up @@ -719,7 +709,7 @@ def write_statepoints(self, statepoints=None, fn=None, indent=2):

def _register(self, job):
"Register the job within the local index."
self._sp_cache[job._id] = dict(job._statepoint)
self._sp_cache[job._id] = job._statepoint._as_dict()

def _get_statepoint_from_workspace(self, jobid):
"Attempt to read the statepoint from the workspace."
Expand Down
22 changes: 22 additions & 0 deletions signac/core/jsondict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@
import logging
from tempfile import mkstemp
from contextlib import contextmanager
from copy import copy

from .errors import Error
from . import json
from .attrdict import SyncedAttrDict
from ..common import six
if six.PY2:
from collections import Mapping
else:
from collections.abc import Mapping


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -297,6 +302,23 @@ def _save(self, data=None):
with open(self._filename, 'wb') as file:
file.write(blob)

def reset(self, data):
"""Replace the document contents with data."""
if isinstance(data, Mapping):
with self._suspend_sync():
backup = copy(self._data)
try:
self._data = {
self._validate_key(k): self._dfs_convert(v)
for k, v in data.items()
}
self._save()
except BaseException: # rollback
self._data = backup
raise
else:
raise ValueError("The document must be a mapping.")

@contextmanager
def buffered(self):
buffered_dict = BufferedSyncedAttrDict(self, parent=self)
Expand Down
28 changes: 19 additions & 9 deletions signac/core/synceddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def outer_wrapped_in_load_and_save(*args, **kwargs):

class _SyncedDict(MutableMapping):

VALID_KEY_TYPES = six.string_types + (int, bool, type(None))

def __init__(self, initialdata=None, parent=None):
self._suspend_sync_ = 1
self._parent = parent
Expand All @@ -89,16 +91,22 @@ def __init__(self, initialdata=None, parent=None):
}
self._suspend_sync_ = 0

@staticmethod
def _validate_key(key):
@classmethod
def _validate_key(cls, key):
"Emit a warning or raise an exception if key is invalid. Returns key."
if '.' in key:
from ..errors import InvalidKeyError
raise InvalidKeyError(
"\nThe use of '.' (dots) in keys is invalid.\n\n"
"See https://signac.io/document-wide-migration/ "
"for a recipe on how to replace dots in existing keys.")
return key
if isinstance(key, six.string_types):
if '.' in key:
from ..errors import InvalidKeyError
raise InvalidKeyError(
"keys may not contain dots ('.'): {}".format(key))
else:
return key
elif isinstance(key, cls.VALID_KEY_TYPES):
return cls._validate_key(str(key))
else:
from ..errors import KeyTypeError
raise KeyTypeError(
"keys must be str, int, bool or None, not {}".format(type(key).__name__))

def _dfs_convert(self, root):
if type(root) is type(self):
Expand All @@ -110,6 +118,8 @@ def _dfs_convert(self, root):
for k in root:
ret[k] = root[k]
return ret
elif type(root) is tuple:
return _SyncedList(root, self)
elif type(root) is list:
return _SyncedList(root, self)
elif NUMPY:
Expand Down
4 changes: 4 additions & 0 deletions signac/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class InvalidKeyError(ValueError):
"""Raised when a user uses a non-conforming key."""


class KeyTypeError(TypeError):
"""Raised when a user uses a key of invalid type."""


__all__ = [
'Error',
'BufferException',
Expand Down
68 changes: 61 additions & 7 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from signac.errors import DestinationExistsError
from signac.errors import JobsCorruptedError
from signac.errors import InvalidKeyError
from signac.errors import KeyTypeError

if six.PY2:
from tempdir import TemporaryDirectory
Expand Down Expand Up @@ -192,13 +193,16 @@ def test_interface_read_write(self):
self.assertEqual(getattr(job.sp.g, x), sp['g'][x])
self.assertEqual(job.sp[x], sp[x])
a = [1, 1.0, '1.0', True, None]
b = list(a) + [a] + [tuple(a)]
for v in b:
for x in ('a', 'b', 'c', 'd', 'e'):
setattr(job.sp, x, v)
self.assertEqual(getattr(job.sp, x), v)
setattr(job.sp.g, x, v)
self.assertEqual(getattr(job.sp.g, x), v)
for x in ('a', 'b', 'c', 'd', 'e'):
setattr(job.sp, x, a)
self.assertEqual(getattr(job.sp, x), a)
setattr(job.sp.g, x, a)
self.assertEqual(getattr(job.sp.g, x), a)
t = (1, 2, 3) # tuple
job.sp.t = t
self.assertEqual(job.sp.t, list(t)) # implicit conversion
job.sp.g.t = t
self.assertEqual(job.sp.g.t, list(t))

def test_interface_job_identity_change(self):
job = self.open_job({'a': 0})
Expand Down Expand Up @@ -328,6 +332,56 @@ def test_interface_multiple_changes(self):
self.assertEqual(job.sp, job2.sp)
self.assertEqual(job.get_id(), job2.get_id())

def test_valid_sp_key_types(self):
job = self.open_job(dict(invalid_key=True)).init()

class A:
pass
for key in ('0', 0, True, False, None):
job.sp[key] = 'test'
self.assertIn(str(key), job.sp)

def test_invalid_sp_key_types(self):
job = self.open_job(dict(invalid_key=True)).init()

class A:
pass
for key in (0.0, A(), (1, 2, 3)):
with self.assertRaises(KeyTypeError):
job.sp[key] = 'test'
with self.assertRaises(KeyTypeError):
job.sp = {key: 'test'}
for key in ([], {}, dict()):
with self.assertRaises(TypeError):
job.sp[key] = 'test'
with self.assertRaises(TypeError):
job.sp = {key: 'test'}

def test_valid_doc_key_types(self):
job = self.open_job(dict(invalid_key=True)).init()

class A:
pass
for key in ('0', 0, True, False, None):
job.doc[key] = 'test'
self.assertIn(str(key), job.doc)

def test_invalid_doc_key_types(self):
job = self.open_job(dict(invalid_key=True)).init()

class A:
pass
for key in (0.0, A(), (1, 2, 3)):
with self.assertRaises(KeyTypeError):
job.doc[key] = 'test'
with self.assertRaises(KeyTypeError):
job.doc = {key: 'test'}
for key in ([], {}, dict()):
with self.assertRaises(TypeError):
job.doc[key] = 'test'
with self.assertRaises(TypeError):
job.doc = {key: 'test'}


class ConfigTest(BaseJobTest):

Expand Down
23 changes: 23 additions & 0 deletions tests/test_jsondict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from signac.core.jsondict import JSONDict
from signac.common import six
from signac.errors import InvalidKeyError
from signac.errors import KeyTypeError

if six.PY2:
from tempdir import TemporaryDirectory
Expand Down Expand Up @@ -231,6 +232,28 @@ def test_keys_with_dots(self):
with self.assertRaises(InvalidKeyError):
jsd['a.b'] = None

def test_keys_valid_type(self):
jsd = self.get_json_dict()

class MyStr(str):
pass
for key in ('key', MyStr('key'), 0, None, True):
d = jsd[key] = self.get_testdata()
self.assertIn(str(key), jsd)
self.assertEqual(jsd[str(key)], d)

def test_keys_invalid_type(self):
jsd = self.get_json_dict()

class A:
pass
for key in (0.0, A(), (1, 2, 3)):
with self.assertRaises(KeyTypeError):
jsd[key] = self.get_testdata()
for key in ([], {}, dict()):
with self.assertRaises(TypeError):
jsd[key] = self.get_testdata()


class JSONDictWriteConcernTest(JSONDictTest):

Expand Down

0 comments on commit 34ac45c

Please sign in to comment.