Skip to content

Commit

Permalink
Merge branch 'master' into type
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Apr 11, 2024
2 parents 513a2f4 + dbaba61 commit f625792
Show file tree
Hide file tree
Showing 8 changed files with 309 additions and 68 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.4
rev: v0.3.5
hooks:
- id: ruff
args: [--fix]
Expand Down
8 changes: 4 additions & 4 deletions monty/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def _is_in_owner_repo() -> bool:
# Only raise warning in code owner's repo CI
if (
_deadline is not None
and os.getenv("CI")
and os.getenv("CI") is not None
and datetime.now() > _deadline
and _is_in_owner_repo()
):
raise DeprecationWarning(
"This function should have been removed on {deadline:%Y-%m-%d}."
f"This function should have been removed on {_deadline:%Y-%m-%d}."
)

def craft_message(
Expand All @@ -94,7 +94,7 @@ def craft_message(
msg = f"{old.__name__} is deprecated"

if deadline is not None:
msg += f", and will be removed on {deadline:%Y-%m-%d}\n"
msg += f", and will be removed on {_deadline:%Y-%m-%d}\n"

if replacement is not None:
if isinstance(replacement, property):
Expand All @@ -120,7 +120,7 @@ def wrapped(*args, **kwargs):
# Convert deadline to datetime type
_deadline = datetime(*deadline) if deadline is not None else None

# Raise a CI warning after removal deadline
# Raise CI warning after removal deadline
raise_deadline_warning()

return deprecated_decorator
Expand Down
183 changes: 167 additions & 16 deletions monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
import os
import pathlib
import pickle
import traceback
import types
from collections import OrderedDict, defaultdict
Expand All @@ -18,7 +19,8 @@
from inspect import getfullargspec
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID
from typing import Any, Dict
from uuid import UUID, uuid4

try:
import numpy as np
Expand Down Expand Up @@ -148,12 +150,12 @@ class MSONable:
class MSONClass(MSONable):
def __init__(self, a, b, c, d=1, **kwargs):
self.a = a
self.b = b
self._c = c
self._d = d
self.kwargs = kwargs
def __init__(self, a, b, c, d=1, **kwargs):
self.a = a
self.b = b
self._c = c
self._d = d
self.kwargs = kwargs
For such classes, you merely need to inherit from MSONable and you do not
need to implement your own as_dict or from_dict protocol.
Expand Down Expand Up @@ -231,6 +233,20 @@ def recursive_as_dict(obj):
d.update({"value": self.value}) # pylint: disable=E1101
return d

@staticmethod
def decoded_from_dict(d, name_object_map):
    """Decode every non-"@"-prefixed entry of ``d`` with a MontyDecoder.

    The provided ``name_object_map`` (mapping of object-reference names to
    out-of-band pickled objects, or None) is attached to the decoder so any
    "@object_reference" placeholders encountered can be resolved.
    """
    monty_decoder = MontyDecoder()
    monty_decoder._name_object_map = name_object_map
    result = {}
    for key, value in d.items():
        if key.startswith("@"):
            continue
        result[key] = monty_decoder.process_decoded(value)
    return result

@classmethod
def _from_dict(cls, d, name_object_map):
    """Build an instance of ``cls`` from dict ``d``, resolving any
    "@object_reference" entries through ``name_object_map``."""
    kwargs = MSONable.decoded_from_dict(d, name_object_map=name_object_map)
    return cls(**kwargs)

@classmethod
def from_dict(cls, d: dict) -> Self:
"""
Expand All @@ -240,11 +256,7 @@ def from_dict(cls, d: dict) -> Self:
Returns:
MSONable class.
"""
decoded = {
k: MontyDecoder().process_decoded(v)
for k, v in d.items()
if not k.startswith("@")
}
decoded = MSONable.decoded_from_dict(d, name_object_map=None)
return cls(**decoded)

def to_json(self) -> str:
Expand All @@ -253,6 +265,111 @@ def to_json(self) -> str:
"""
return json.dumps(self, cls=MontyEncoder)

def save(
    self,
    save_dir=None,
    mkdir=True,
    pickle_kwargs=None,
    json_kwargs=None,
    return_results=False,
    strict=True,
):
    """Utility that uses the standard tools of MSONable to convert the
    class to json format, but also save it to disk. In addition, this
    method intelligently uses pickle to individually pickle class objects
    that are not serializable, saving them separately. This maximizes the
    readability of the saved class information while allowing _any_
    class to be at least partially serializable to disk.

    For a fully MSONable class, only a class.json file will be saved to
    the location {save_dir}/class.json. For a partially MSONable class,
    additional information will be saved to the save directory at
    {save_dir}. This includes a pickled object for each attribute that
    cannot be serialized to JSON.

    Parameters
    ----------
    save_dir : os.PathLike
        The directory to which to save the class information.
    mkdir : bool
        If True, makes the provided directory, including all parent
        directories.
    pickle_kwargs : dict
        Keyword arguments to pass to pickle.dump.
    json_kwargs : dict
        Keyword arguments to pass to the serializer.
    return_results : bool
        If True, also returns the dictionary to save to disk, as well
        as the mapping between the object_references and the objects
        themselves.
    strict : bool
        If True, will not allow you to overwrite existing files.

    Returns
    -------
    None or tuple
    """

    # Nothing to do if we neither persist nor hand the results back.
    if save_dir is None and not return_results:
        raise ValueError("save_dir must be set and/or return_results must be True")

    if pickle_kwargs is None:
        pickle_kwargs = {}
    if json_kwargs is None:
        json_kwargs = {}

    # allow_unserializable_objects=True makes the encoder stash any
    # non-serializable attribute in _name_object_map instead of raising.
    encoder = MontyEncoder(allow_unserializable_objects=True, **json_kwargs)
    encoded = encoder.encode(self)

    if save_dir is not None:
        save_dir = Path(save_dir)
        if mkdir:
            save_dir.mkdir(exist_ok=True, parents=True)
        json_path = save_dir / "class.json"
        pickle_path = save_dir / "class.pkl"
        if strict and json_path.exists():
            raise FileExistsError(f"strict is true and file {json_path} exists")
        if strict and pickle_path.exists():
            raise FileExistsError(f"strict is true and file {pickle_path} exists")

        with open(json_path, "w") as outfile:
            outfile.write(encoded)
        # Use a context manager so the pickle file handle is always closed
        # (the previous open(...) call leaked the handle on error).
        with open(pickle_path, "wb") as pickle_file:
            pickle.dump(encoder._name_object_map, pickle_file, **pickle_kwargs)

    if return_results:
        return encoded, encoder._name_object_map

@classmethod
def load(cls, load_dir):
    """Loads a class from a provided {load_dir}/class.json and
    {load_dir}/class.pkl file (if necessary).

    Parameters
    ----------
    load_dir : os.PathLike
        The directory from which to reload the class from.

    Returns
    -------
    MSONable
        An instance of the class being reloaded.
    """

    load_dir = Path(load_dir)

    json_path = load_dir / "class.json"
    pickle_path = load_dir / "class.pkl"

    with open(json_path, "r") as infile:
        d = json.load(infile)
    # Context manager fixes the leaked file handle from the previous
    # pickle.load(open(...)) form.
    # NOTE(review): pickle.load executes arbitrary code — only call load()
    # on directories written by a trusted save().
    with open(pickle_path, "rb") as pickle_file:
        name_object_map = pickle.load(pickle_file)

    decoded = MSONable.decoded_from_dict(d, name_object_map)
    return cls(**decoded)

def unsafe_hash(self):
"""
Returns an hash of the current object. This uses a generic but low
Expand Down Expand Up @@ -376,7 +493,19 @@ class MontyEncoder(json.JSONEncoder):
json.dumps(object, cls=MontyEncoder)
"""

def default(self, o) -> dict:
def __init__(self, *args, allow_unserializable_objects=False, **kwargs):
    """Initialize the encoder.

    Args:
        allow_unserializable_objects: when True, objects that cannot be
            JSON-serialized are stored out-of-band in ``_name_object_map``
            and replaced in the output by an "@object_reference" entry.
    """
    super().__init__(*args, **kwargs)
    # Running counter used to build stable, ordered reference names.
    self._index = 0
    self._name_object_map: Dict[str, Any] = {}
    self._track_unserializable_objects = allow_unserializable_objects

def _update_name_object_map(self, o):
name = f"{self._index:012}-{str(uuid4())}"
self._index += 1
self._name_object_map[name] = o
return {"@object_reference": name}

def default(self, o) -> dict: # pylint: disable=E0202
"""
Overriding default method for JSON encoding. This method does two
things: (a) If an object has a to_dict property, return the to_dict
Expand Down Expand Up @@ -445,7 +574,13 @@ def default(self, o) -> dict:
return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)}

if callable(o) and not isinstance(o, MSONable):
return _serialize_callable(o)
try:
return _serialize_callable(o)
except AttributeError as e:
# Some callables may not have instance __name__
if self._track_unserializable_objects:
return self._update_name_object_map(o)
raise AttributeError(e)

try:
if pydantic is not None and isinstance(o, pydantic.BaseModel):
Expand All @@ -461,6 +596,11 @@ def default(self, o) -> dict:
d = o.as_dict()
elif isinstance(o, Enum):
d = {"value": o.value}
elif self._track_unserializable_objects:
# Last resort logic. We keep track of some name of the object
# as a reference, and instead of the object, store that
# name, which of course is json-serializable
d = self._update_name_object_map(o)
else:
raise TypeError(
f"Object of type {o.__class__.__name__} is not JSON serializable"
Expand Down Expand Up @@ -497,13 +637,19 @@ class MontyDecoder(json.JSONDecoder):
json.loads(json_string, cls=MontyDecoder)
"""

def process_decoded(self, d: Any):
_name_object_map = None

def process_decoded(self, d):
"""
Recursive method to support decoding dicts and lists containing
pymatgen objects.
"""

if isinstance(d, dict):
if "@module" in d and "@class" in d:
if "@object_reference" in d and self._name_object_map is not None:
name = d["@object_reference"]
return self._name_object_map.pop(name)
elif "@module" in d and "@class" in d:
modname = d["@module"]
classname = d["@class"]
if cls_redirect := MSONable.REDIRECT.get(modname, {}).get(classname):
Expand Down Expand Up @@ -566,6 +712,11 @@ def process_decoded(self, d: Any):
if hasattr(mod, classname):
cls_ = getattr(mod, classname)
data = {k: v for k, v in d.items() if not k.startswith("@")}
if hasattr(cls_, "_from_dict"):
# New functionality with save/load requires this
return cls_._from_dict(
data, name_object_map=self._name_object_map
)
if hasattr(cls_, "from_dict"):
return cls_.from_dict(data)
if issubclass(cls_, Enum):
Expand Down
4 changes: 2 additions & 2 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ruamel.yaml==0.18.6
msgpack==1.0.8
tqdm==4.66.2
pymongo==4.6.3
pandas==2.2.1
pandas==2.2.2
orjson==3.10.0
types-orjson==3.6.2
types-requests==2.31.0.20240311
types-requests==2.31.0.20240406
Loading

0 comments on commit f625792

Please sign in to comment.