Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for an MSONAtoms class that's an MSONable form of an ASE Atoms object #3619

Merged
merged 7 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 52 additions & 7 deletions pymatgen/io/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Atoms object and pymatgen Structure objects.
"""


from __future__ import annotations

import warnings
Expand All @@ -12,16 +11,19 @@
from typing import TYPE_CHECKING

import numpy as np
from monty.json import MSONable

from pymatgen.core.structure import Molecule, Structure

if TYPE_CHECKING:
from typing import Any

from numpy.typing import ArrayLike

from pymatgen.core.structure import SiteCollection

try:
from ase import Atoms
from ase.atoms import Atoms
from ase.calculators.singlepoint import SinglePointDFTCalculator
from ase.constraints import FixAtoms
from ase.spacegroup import Spacegroup
Expand All @@ -38,18 +40,41 @@
__date__ = "Mar 8, 2012"


class MSONAtoms(Atoms, MSONable):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something here breaks imports in cases where the user doesn't have ASE installed. This seems to be breaking the latest mp_api as a result:

from mp_api.client.mprester import MPRester
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/mp_api/client/__init__.py", line 8, in <module>
    from .mprester import MPRester
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/mp_api/client/mprester.py", line 10, in <module>
    from emmet.core.electronic_structure import BSPathType
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/emmet/core/electronic_structure.py", line 11, in <module>
    from pymatgen.analysis.magnetism.analyzer import (
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/pymatgen/analysis/magnetism/__init__.py", line 5, in <module>
    from pymatgen.analysis.magnetism.analyzer import (
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/pymatgen/analysis/magnetism/analyzer.py", line 24, in <module>
    from pymatgen.transformations.advanced_transformations import MagOrderingTransformation, MagOrderParameterConstraint
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/pymatgen/transformations/advanced_transformations.py", line 33, in <module>
    from pymatgen.io.ase import AseAtomsAdaptor
  File "/home/mevans/repos/re2fractive/re2fractive/experiments/2_featurizing_refractive_index_data/.venv-mp-api/lib/python3.11/site-packages/pymatgen/io/ase.py", line 44, in <module>
    class MSONAtoms(Atoms, MSONable):
                    ^^^^^
NameError: name 'Atoms' is not defined

Will raise an issue

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see #3644

"""A custom subclass of ASE Atoms that is MSONable, including `.as_dict()` and `.from_dict()` methods."""

def as_dict(s: Atoms) -> dict[str, Any]:
from ase.io.jsonio import encode

# Normally, we would want to this to be a wrapper around atoms.todict() with @module and
# @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant
# to be used in a round-trip fashion and does not work properly with constraints.
# See ASE issue #1387.
return {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(s)}

def from_dict(d: dict[str, Any]) -> Atoms:
from ase.io.jsonio import decode

# Normally, we would want to this to be a wrapper around atoms.fromdict() with @module and
# @class key-value pairs inserted. However, atoms.todict()/atoms.fromdict() is not meant
# to be used in a round-trip fashion and does not work properly with constraints.
# See ASE issue #1387.
return decode(d["atoms_json"])


# NOTE: If making notable changes to this class, please ping @Andrew-S-Rosen on GitHub.
# There are some subtleties in here, particularly related to spins/charges.
class AseAtomsAdaptor:
"""Adaptor serves as a bridge between ASE Atoms and pymatgen objects."""

@staticmethod
def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:
def get_atoms(structure: SiteCollection, msonable: bool = True, **kwargs) -> MSONAtoms | Atoms:
"""
Returns ASE Atoms object from pymatgen structure or molecule.

Args:
structure (SiteCollection): pymatgen Structure or Molecule
msonable (bool): Whether to return an MSONAtoms object, which is MSONable.
**kwargs: passed to the ASE Atoms constructor

Returns:
Expand All @@ -72,6 +97,9 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:

atoms = Atoms(symbols=symbols, positions=positions, pbc=pbc, cell=cell, **kwargs)

if msonable:
atoms = MSONAtoms(atoms)

if "tags" in structure.site_properties:
atoms.set_tags(structure.site_properties["tags"])

Expand Down Expand Up @@ -142,7 +170,13 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:

# Add any remaining site properties to the ASE Atoms object
for prop in structure.site_properties:
if prop not in ["magmom", "charge", "final_magmom", "final_charge", "selective_dynamics"]:
if prop not in [
"magmom",
"charge",
"final_magmom",
"final_charge",
"selective_dynamics",
]:
atoms.set_array(prop, np.array(structure.site_properties[prop]))
if any(oxi_states):
atoms.set_array("oxi_states", np.array(oxi_states))
Expand All @@ -154,7 +188,8 @@ def get_atoms(structure: SiteCollection, **kwargs) -> Atoms:
# Regenerate Spacegroup object from `.todict()` representation
if isinstance(atoms.info.get("spacegroup"), dict):
atoms.info["spacegroup"] = Spacegroup(
atoms.info["spacegroup"]["number"], setting=atoms.info["spacegroup"].get("setting", 1)
atoms.info["spacegroup"]["number"],
setting=atoms.info["spacegroup"].get("setting", 1),
)

# Atoms.calc <---> Structure.calc
Expand Down Expand Up @@ -216,7 +251,10 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
else:
unsupported_constraint_type = True
if unsupported_constraint_type:
warnings.warn("Only FixAtoms is supported by Pymatgen. Other constraints will not be set.", UserWarning)
warnings.warn(
"Only FixAtoms is supported by Pymatgen. Other constraints will not be set.",
UserWarning,
)
sel_dyn = [[False] * 3 if atom.index in constraint_indices else [True] * 3 for atom in atoms]
else:
sel_dyn = None
Expand All @@ -232,7 +270,14 @@ def get_structure(atoms: Atoms, cls: type[Structure] = Structure, **cls_kwargs)
if cls == Molecule:
structure = cls(symbols, positions, properties=properties, **cls_kwargs)
else:
structure = cls(lattice, symbols, positions, coords_are_cartesian=True, properties=properties, **cls_kwargs)
structure = cls(
lattice,
symbols,
positions,
coords_are_cartesian=True,
properties=properties,
**cls_kwargs,
)

# Atoms.calc <---> Structure.calc
if calc := getattr(atoms, "calc", None):
Expand Down
36 changes: 33 additions & 3 deletions tests/io/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pymatgen.core import Composition, Lattice, Molecule, Structure
from pymatgen.core.structure import StructureError
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.ase import AseAtomsAdaptor, MSONAtoms
from pymatgen.util.testing import TEST_FILES_DIR

ase = pytest.importorskip("ase")
Expand Down Expand Up @@ -147,7 +147,10 @@ def test_get_structure():
atoms = read(f"{TEST_FILES_DIR}/POSCAR_overlap")
struct = AseAtomsAdaptor.get_structure(atoms)
assert [s.species_string for s in struct] == atoms.get_chemical_symbols()
with pytest.raises(StructureError, match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart"):
with pytest.raises(
StructureError,
match=f"sites are less than {struct.DISTANCE_TOLERANCE} Angstrom apart",
):
struct = AseAtomsAdaptor.get_structure(atoms, validate_proximity=True)


Expand All @@ -169,7 +172,12 @@ def test_get_structure_mag():

@pytest.mark.parametrize(
"select_dyn",
[[True, True, True], [False, False, False], np.array([True, True, True]), np.array([False, False, False])],
[
[True, True, True],
[False, False, False],
np.array([True, True, True]),
np.array([False, False, False]),
],
)
def test_get_structure_dyn(select_dyn):
atoms = read(f"{TEST_FILES_DIR}/POSCAR")
Expand Down Expand Up @@ -289,3 +297,25 @@ def test_back_forth_v4():
# test document can be jsanitized and decoded
dct = jsanitize(molecule, strict=True, enum_values=True)
MontyDecoder().process_decoded(dct)


def test_msonable_atoms():
from ase.io.jsonio import encode

atoms = read(f"{TEST_FILES_DIR}/OUTCAR")
ref = {"@module": "ase.atoms", "@class": "Atoms", "atoms_json": encode(atoms)}
msonable_atoms = MSONAtoms(atoms)
assert msonable_atoms.as_dict() == ref
assert MSONAtoms.from_dict(ref) == atoms


def test_msonable_atoms_v2():
structure = Structure.from_file(f"{TEST_FILES_DIR}/POSCAR")

atoms = AseAtomsAdaptor.get_atoms(structure, msonable=True)
assert hasattr(atoms, "as_dict")
assert hasattr(atoms, "from_dict")

atoms = AseAtomsAdaptor.get_atoms(structure, msonable=False)
assert not hasattr(atoms, "as_dict")
assert not hasattr(atoms, "from_dict")