Skip to content

Commit

Permalink
Merge pull request #2684 from materialsproject/more-ppd-tests
Browse files Browse the repository at this point in the history
~2x number of PatchedPhaseDiagrams tests
  • Loading branch information
janosh committed Oct 12, 2022
2 parents 8ae54c7 + 589bdfc commit db8ad7a
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 71 deletions.
100 changes: 53 additions & 47 deletions pymatgen/analysis/phase_diagram.py
Expand Up @@ -26,7 +26,7 @@
from tqdm import tqdm

from pymatgen.analysis.reaction_calculator import Reaction, ReactionError
from pymatgen.core.composition import Composition, SpeciesLike
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import DummySpecies, Element, get_el_sp
from pymatgen.entries import Entry
from pymatgen.util.coord import Simplex, in_coord_list
Expand Down Expand Up @@ -336,7 +336,13 @@ class PhaseDiagram(MSONable):
formation_energy_tol = 1e-11
numerical_tol = 1e-8

def __init__(self, entries, elements: Sequence[SpeciesLike] = (), *, computed_data=None) -> None:
def __init__(
self,
entries: Sequence[PDEntry] | set[PDEntry],
elements: Sequence[Element] = (),
*,
computed_data: dict[str, Any] = None,
) -> None:
"""
Args:
entries (list[PDEntry]): A list of PDEntry-like objects having an
Expand All @@ -358,6 +364,7 @@ def __init__(self, entries, elements: Sequence[SpeciesLike] = (), *, computed_da
computed_data = self._compute()
else:
computed_data = MontyDecoder().process_decoded(computed_data)
assert isinstance(computed_data, dict) # type narrowing to appease mypy
self.computed_data = computed_data
self.facets = computed_data["facets"]
self.simplexes = computed_data["simplexes"]
Expand Down Expand Up @@ -409,13 +416,13 @@ def _compute(self):
el_refs = {}
min_entries = []
all_entries = []
for c, g in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
g = list(g)
min_entry = min(g, key=lambda e: e.energy_per_atom)
if c.is_element:
el_refs[c.elements[0]] = min_entry
for composition, group in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
group = list(group)
min_entry = min(group, key=lambda e: e.energy_per_atom)
if composition.is_element:
el_refs[composition.elements[0]] = min_entry
min_entries.append(min_entry)
all_entries.extend(g)
all_entries.extend(group)

if len(el_refs) < dim:
missing = set(elements) - set(el_refs)
Expand All @@ -431,13 +438,13 @@ def _compute(self):
# Use only entries with negative formation energy
vec = [el_refs[el].energy_per_atom for el in elements] + [-1]
form_e = -np.dot(data, vec)
inds = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()
idx = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()

# Add the elemental references
inds.extend([min_entries.index(el) for el in el_refs.values()])
idx.extend([min_entries.index(el) for el in el_refs.values()])

qhull_entries = [min_entries[i] for i in inds]
qhull_data = data[inds][:, 1:]
qhull_entries = [min_entries[i] for i in idx]
qhull_data = data[idx][:, 1:]

# Add an extra point to enforce full dimensionality.
# This point will be present in all upper hull facets.
Expand Down Expand Up @@ -1473,25 +1480,32 @@ class PatchedPhaseDiagram(PhaseDiagram):
PhaseDiagrams within the PatchedPhaseDiagram.
pds ({str: PhaseDiagram}): Dictionary of PhaseDiagrams within the
PatchedPhaseDiagram.
all_entries ([PDEntry, ]): All entries provided for Phase Diagram construction.
all_entries (list[PDEntry]): All entries provided for Phase Diagram construction.
Note that this does not mean that all these entries are actually used in
the phase diagram. For example, this includes the positive formation energy
entries that are filtered out before Phase Diagram construction.
min_entries ([PDEntry, ]): List of the lowest energy entries for each composition
min_entries (list[PDEntry]): List of the lowest energy entries for each composition
in the data provided for Phase Diagram construction.
el_refs ([PDEntry, ]): List of elemental references for the phase diagrams.
el_refs (list[PDEntry]): List of elemental references for the phase diagrams.
These are entries corresponding to the lowest energy element entries for
simple compositional phase diagrams.
elements ([Element, ]): List of elements in the phase diagram.
elements (list[Element]): List of elements in the phase diagram.
"""

def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False):
def __init__(
self,
entries: Sequence[PDEntry] | set[PDEntry],
elements: Sequence[Element] = None,
keep_all_spaces: bool = False,
verbose: bool = False,
computed_data: dict[str, Any] = None,
) -> None:
"""
Args:
entries ([PDEntry, ]): A list of PDEntry-like objects having an
entries (list[PDEntry]): A list of PDEntry-like objects having an
energy, energy_per_atom and composition.
elements ([Element, ], optional): Optional list of elements in the phase
elements (list[Element], optional): Optional list of elements in the phase
diagram. If set to None, the elements are determined from
the entries themselves and are sorted alphabetically.
If specified, element ordering (e.g. for pd coordinates)
Expand All @@ -1502,22 +1516,20 @@ def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False)
if elements is None:
elements = sorted({els for e in entries for els in e.composition.elements})

elements = list(elements)

self.dim = len(elements)

entries = sorted(entries, key=lambda e: e.composition.reduced_composition)

el_refs = {}
el_refs: dict[Element, PDEntry] = {}
min_entries = []
all_entries = []
for c, g in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
g = list(g)
min_entry = min(g, key=lambda e: e.energy_per_atom)
if c.is_element:
el_refs[c.elements[0]] = min_entry
all_entries: list[PDEntry] = []
for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
group = list(group_iter)
min_entry = min(group, key=lambda e: e.energy_per_atom)
if composition.is_element:
el_refs[composition.elements[0]] = min_entry
min_entries.append(min_entry)
all_entries.extend(g)
all_entries.extend(group)

if len(el_refs) < self.dim:
missing = set(elements) - set(el_refs)
Expand Down Expand Up @@ -1548,41 +1560,35 @@ def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False)
if not keep_all_spaces and len(spaces) > 1:
max_size = max(len(s) for s in spaces)

systems = []
systems = set()
# NOTE reduce the number of comparisons by only comparing to larger sets
for i in range(2, max_size + 1):
test = (s for s in spaces if len(s) == i)
refer = (s for s in spaces if len(s) > i)
systems.extend([t for t in test if not any(t.issubset(r) for r in refer)])
systems |= {t for t in test if not any(t.issubset(r) for r in refer)}

spaces = systems

# Calculate pds for smaller dimension spaces first
spaces = sorted(spaces, key=len, reverse=False)

pds = [self._get_pd_patch_for_space(s) for s in tqdm(spaces, disable=(not verbose))]
pds = dict(pds)

# TODO comprhys: refactor to have self._compute method to allow serialisation
self.spaces = spaces
self.pds = pds
self.spaces = sorted(spaces, key=len, reverse=False) # Calculate pds for smaller dimension spaces first
self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=(not verbose)))
self.all_entries = all_entries
self.el_refs = el_refs
self.elements = elements

# Add terminal elements as we may not have PD patches including them
# NOTE add el_refs in case no multielement entries are present for el
_stable_entries = {se for pd in pds.values() for se in pd._stable_entries}
_stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries}
self._stable_entries = tuple(_stable_entries.union(self.el_refs.values()))
self._stable_spaces = tuple(frozenset(e.composition.elements) for e in self._stable_entries)

def __repr__(self):
return f"{type(self).__name__}\n Covering {len(self.spaces)} Sub-Spaces"
return f"{type(self).__name__} covering {len(self.spaces)} sub-spaces"

def as_dict(self):
def as_dict(self) -> dict[str, Any]:
"""
Returns:
MSONable dictionary representation of PatchedPhaseDiagram
dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram
"""
return {
"@module": type(self).__module__,
Expand Down Expand Up @@ -1618,15 +1624,15 @@ def from_dict(cls, d):
# get_decomp_and_phase_separation_energy(),
# get_phase_separation_energy()

def get_pd_for_entry(self, entry):
def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram:
"""
Get the possible phase diagrams for an entry
Args:
entry (PDEntry/Composition): a PDEntry or Composition like entry
entry (PDEntry | Composition): A PDEntry or Composition-like object
Returns:
Dictionary of {space: PhaseDiagram} that the entry is part of
PhaseDiagram: phase diagram that the entry is part of
"""
if isinstance(entry, Composition):
entry_space = frozenset(entry.elements)
Expand All @@ -1642,7 +1648,7 @@ def get_pd_for_entry(self, entry):

raise ValueError(f"No suitable PhaseDiagrams found for {entry}.")

def get_decomposition(self, comp):
def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]:
"""
See PhaseDiagram
Expand Down
92 changes: 68 additions & 24 deletions pymatgen/analysis/tests/test_phase_diagram.py
Expand Up @@ -93,7 +93,7 @@ def test_str(self):
self.assertEqual(str(pde), "PDEntry : Li1 Fe1 O2 with energy = 53.0000")

def test_read_csv(self):
entries = EntrySet.from_csv(str(module_dir / "pdentries_test.csv"))
entries = EntrySet.from_csv(module_dir / "pdentries_test.csv")
self.assertEqual(entries.chemsys, {"Li", "Fe", "O"}, "Wrong elements!")
self.assertEqual(len(entries), 490, "Wrong number of entries!")

Expand Down Expand Up @@ -149,7 +149,7 @@ def test_normalize(self):

class PhaseDiagramTest(unittest.TestCase):
def setUp(self):
self.entries = EntrySet.from_csv(str(module_dir / "pdentries_test.csv"))
self.entries = EntrySet.from_csv(module_dir / "pdentries_test.csv")
self.pd = PhaseDiagram(self.entries)
warnings.simplefilter("ignore")

Expand Down Expand Up @@ -634,7 +634,7 @@ def test_read_json(self):

class GrandPotentialPhaseDiagramTest(unittest.TestCase):
def setUp(self):
self.entries = EntrySet.from_csv(str(module_dir / "pdentries_test.csv"))
self.entries = EntrySet.from_csv(module_dir / "pdentries_test.csv")
self.pd = GrandPotentialPhaseDiagram(self.entries, {Element("O"): -5})
self.pd6 = GrandPotentialPhaseDiagram(self.entries, {Element("O"): -6})

Expand Down Expand Up @@ -671,7 +671,7 @@ def test_str(self):

class CompoundPhaseDiagramTest(unittest.TestCase):
def setUp(self):
self.entries = EntrySet.from_csv(str(module_dir / "pdentries_test.csv"))
self.entries = EntrySet.from_csv(module_dir / "pdentries_test.csv")
self.pd = CompoundPhaseDiagram(self.entries, [Composition("Li2O"), Composition("Fe2O3")])

def test_stable_entries(self):
Expand All @@ -697,9 +697,10 @@ def test_str(self):

class PatchedPhaseDiagramTest(unittest.TestCase):
def setUp(self):
self.entries = EntrySet.from_csv(str(module_dir / "reaction_entries_test.csv"))
# NOTE add He to test for correct behaviour despite no patches involving He
self.entries.add(PDEntry("He", -1.23))
self.entries = EntrySet.from_csv(module_dir / "reaction_entries_test.csv")
# NOTE add He to test for correct behavior despite no patches involving He
self.no_patch_entry = he_entry = PDEntry("He", -1.23)
self.entries.add(he_entry)

self.pd = PhaseDiagram(entries=self.entries)
self.ppd = PatchedPhaseDiagram(entries=self.entries)
Expand All @@ -712,34 +713,77 @@ def setUp(self):
self.novel_entries = [PDEntry(c, -39.8) for c in self.novel_comps]

def test_get_stable_entries(self):
self.assertEqual(self.pd.stable_entries, self.ppd.stable_entries)
assert self.pd.stable_entries == self.ppd.stable_entries

def test_get_qhull_entries(self):
# NOTE qhull_entry is an specially sorted list due to it's construction, we
# can't mimic this in ppd therefore just test if sorted versions are equal.
self.assertEqual(
sorted(self.pd.qhull_entries, key=lambda e: e.composition.reduced_composition),
sorted(self.ppd.qhull_entries, key=lambda e: e.composition.reduced_composition),
assert sorted(self.pd.qhull_entries, key=lambda e: e.composition) == sorted(
self.ppd.qhull_entries, key=lambda e: e.composition
)

def test_get_decomposition(self):
for c in self.novel_comps:
pd_decomp = self.pd.get_decomposition(c)
ppd_decomp = self.ppd.get_decomposition(c)

# NOTE unittest doesn't have an assert almost equal for dictionaries.
for e in pd_decomp:
self.assertAlmostEqual(pd_decomp[e], ppd_decomp[e], 7)
for comp in self.novel_comps:
decomp_pd = self.pd.get_decomposition(comp)
decomp_ppd = self.ppd.get_decomposition(comp)
assert decomp_pd == pytest.approx(decomp_ppd)

def test_get_phase_separation_energy(self):
for e in self.novel_entries:
self.assertAlmostEqual(self.pd.get_phase_separation_energy(e), self.ppd.get_phase_separation_energy(e), 7)
for entry in self.novel_entries:
e_phase_sep_pd = self.pd.get_phase_separation_energy(entry)
e_phase_sep_ppd = self.ppd.get_phase_separation_energy(entry)
assert np.isclose(e_phase_sep_pd, e_phase_sep_ppd)

def test_get_equilibrium_reaction_energy(self):
for e in self.pd.stable_entries:
self.assertAlmostEqual(
self.pd.get_equilibrium_reaction_energy(e), self.ppd.get_equilibrium_reaction_energy(e), 7
)
for entry in self.pd.stable_entries:
e_equi_rxn_pd = self.pd.get_equilibrium_reaction_energy(entry)
e_equi_rxn_pdd = self.ppd.get_equilibrium_reaction_energy(entry)
assert np.isclose(e_equi_rxn_pd, e_equi_rxn_pdd)

def test_get_form_energy(self):
for entry in self.pd.stable_entries:
e_form_pd = self.pd.get_form_energy(entry)
e_form_ppd = self.ppd.get_form_energy(entry)
assert np.isclose(e_form_pd, e_form_ppd)

def test_dimensionality(self):
assert self.pd.dim == self.ppd.dim

def test_get_hull_energy(self):
for comp in self.novel_comps:
e_hull_pd = self.pd.get_hull_energy(comp)
e_hull_ppd = self.ppd.get_hull_energy(comp)
assert np.isclose(e_hull_pd, e_hull_ppd)

def test_get_decomp_and_e_above_hull(self):
for entry in self.pd.stable_entries:
decomp_pd, e_above_hull_pd = self.pd.get_decomp_and_e_above_hull(entry)
decomp_ppd, e_above_hull_ppd = self.ppd.get_decomp_and_e_above_hull(entry)
assert decomp_pd == decomp_ppd
assert np.isclose(e_above_hull_pd, e_above_hull_ppd)

def test_repr(self):
assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 15 sub-spaces"

def test_as_from_dict(self):
ppd_dict = self.ppd.as_dict()
assert ppd_dict["@module"] == self.ppd.__class__.__module__
assert ppd_dict["@class"] == self.ppd.__class__.__name__
assert ppd_dict["all_entries"] == [entry.as_dict() for entry in self.ppd.all_entries]
assert ppd_dict["elements"] == [elem.as_dict() for elem in self.ppd.elements]
# test round-trip dict serialization
assert PatchedPhaseDiagram.from_dict(ppd_dict).as_dict() == ppd_dict

def test_get_pd_for_entry(self):
for entry in self.ppd.all_entries:
if entry == self.no_patch_entry:
continue
pd = self.ppd.get_pd_for_entry(entry)
# test that entry is in pd and pd can return valid decomp
assert isinstance(pd.get_decomposition(entry.composition), dict)

with pytest.raises(ValueError, match="No suitable PhaseDiagrams found for PDEntry"):
self.ppd.get_pd_for_entry(self.no_patch_entry)

def test_raises_on_missing_terminal_entries(self):
entry = PDEntry("FeO", -1.23)
Expand Down

0 comments on commit db8ad7a

Please sign in to comment.