Skip to content

Commit

Permalink
add type hints and fix wrong doc str return type in phase_diagram.py
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Oct 12, 2022
1 parent ebf6ba2 commit 4904bbc
Showing 1 changed file with 44 additions and 38 deletions.
82 changes: 44 additions & 38 deletions pymatgen/analysis/phase_diagram.py
Expand Up @@ -26,7 +26,7 @@
from tqdm import tqdm

from pymatgen.analysis.reaction_calculator import Reaction, ReactionError
from pymatgen.core.composition import Composition, SpeciesLike
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import DummySpecies, Element, get_el_sp
from pymatgen.entries import Entry
from pymatgen.util.coord import Simplex, in_coord_list
Expand Down Expand Up @@ -336,7 +336,13 @@ class PhaseDiagram(MSONable):
formation_energy_tol = 1e-11
numerical_tol = 1e-8

def __init__(self, entries, elements: Sequence[SpeciesLike] = (), *, computed_data=None) -> None:
def __init__(
self,
entries: Sequence[PDEntry] | set[PDEntry],
elements: Sequence[Element] = (),
*,
computed_data: dict[str, Any] = None,
) -> None:
"""
Args:
entries (list[PDEntry]): A list of PDEntry-like objects having an
Expand All @@ -358,6 +364,7 @@ def __init__(self, entries, elements: Sequence[SpeciesLike] = (), *, computed_da
computed_data = self._compute()
else:
computed_data = MontyDecoder().process_decoded(computed_data)
assert isinstance(computed_data, dict) # type narrowing to appease mypy
self.computed_data = computed_data
self.facets = computed_data["facets"]
self.simplexes = computed_data["simplexes"]
Expand Down Expand Up @@ -431,13 +438,13 @@ def _compute(self):
# Use only entries with negative formation energy
vec = [el_refs[el].energy_per_atom for el in elements] + [-1]
form_e = -np.dot(data, vec)
inds = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()
idx = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()

# Add the elemental references
inds.extend([min_entries.index(el) for el in el_refs.values()])
idx.extend([min_entries.index(el) for el in el_refs.values()])

qhull_entries = [min_entries[i] for i in inds]
qhull_data = data[inds][:, 1:]
qhull_entries = [min_entries[i] for i in idx]
qhull_data = data[idx][:, 1:]

# Add an extra point to enforce full dimensionality.
# This point will be present in all upper hull facets.
Expand Down Expand Up @@ -1473,25 +1480,32 @@ class PatchedPhaseDiagram(PhaseDiagram):
PhaseDiagrams within the PatchedPhaseDiagram.
pds ({str: PhaseDiagram}): Dictionary of PhaseDiagrams within the
PatchedPhaseDiagram.
all_entries ([PDEntry, ]): All entries provided for Phase Diagram construction.
all_entries (list[PDEntry]): All entries provided for Phase Diagram construction.
Note that this does not mean that all these entries are actually used in
the phase diagram. For example, this includes the positive formation energy
entries that are filtered out before Phase Diagram construction.
min_entries ([PDEntry, ]): List of the lowest energy entries for each composition
min_entries (list[PDEntry]): List of the lowest energy entries for each composition
in the data provided for Phase Diagram construction.
el_refs ([PDEntry, ]): List of elemental references for the phase diagrams.
el_refs (list[PDEntry]): List of elemental references for the phase diagrams.
These are entries corresponding to the lowest energy element entries for
simple compositional phase diagrams.
elements ([Element, ]): List of elements in the phase diagram.
elements (list[Element]): List of elements in the phase diagram.
"""

def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False):
def __init__(
self,
entries: Sequence[PDEntry] | set[PDEntry],
elements: Sequence[Element] = None,
keep_all_spaces: bool = False,
verbose: bool = False,
computed_data: dict[str, Any] = None,
) -> None:
"""
Args:
entries ([PDEntry, ]): A list of PDEntry-like objects having an
entries (list[PDEntry]): A list of PDEntry-like objects having an
energy, energy_per_atom and composition.
elements ([Element, ], optional): Optional list of elements in the phase
elements (list[Element], optional): Optional list of elements in the phase
diagram. If set to None, the elements are determined from
the entries themselves and are sorted alphabetically.
If specified, element ordering (e.g. for pd coordinates)
Expand All @@ -1502,22 +1516,20 @@ def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False)
if elements is None:
elements = sorted({els for e in entries for els in e.composition.elements})

elements = list(elements)

self.dim = len(elements)

entries = sorted(entries, key=lambda e: e.composition.reduced_composition)

el_refs = {}
el_refs: dict[Element, PDEntry] = {}
min_entries = []
all_entries = []
for c, g in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
g = list(g)
min_entry = min(g, key=lambda e: e.energy_per_atom)
all_entries: list[PDEntry] = []
for c, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
group = list(group_iter)
min_entry = min(group, key=lambda e: e.energy_per_atom)
if c.is_element:
el_refs[c.elements[0]] = min_entry
min_entries.append(min_entry)
all_entries.extend(g)
all_entries.extend(group)

if len(el_refs) < self.dim:
missing = set(elements) - set(el_refs)
Expand Down Expand Up @@ -1555,34 +1567,28 @@ def __init__(self, entries, elements=None, keep_all_spaces=False, verbose=False)
refer = (s for s in spaces if len(s) > i)
systems.extend([t for t in test if not any(t.issubset(r) for r in refer)])

spaces = systems

# Calculate pds for smaller dimension spaces first
spaces = sorted(spaces, key=len, reverse=False)

pds = [self._get_pd_patch_for_space(s) for s in tqdm(spaces, disable=(not verbose))]
pds = dict(pds)
spaces = {*systems}

# TODO comprhys: refactor to have self._compute method to allow serialisation
self.spaces = spaces
self.pds = pds
self.spaces = sorted(spaces, key=len, reverse=False) # Calculate pds for smaller dimension spaces first
self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=(not verbose)))
self.all_entries = all_entries
self.el_refs = el_refs
self.elements = elements

# Add terminal elements as we may not have PD patches including them
# NOTE add el_refs in case no multielement entries are present for el
_stable_entries = {se for pd in pds.values() for se in pd._stable_entries}
_stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries}
self._stable_entries = tuple(_stable_entries.union(self.el_refs.values()))
self._stable_spaces = tuple(frozenset(e.composition.elements) for e in self._stable_entries)

def __repr__(self):
return f"{type(self).__name__}\n Covering {len(self.spaces)} Sub-Spaces"
return f"{type(self).__name__} covering {len(self.spaces)} sub-spaces"

def as_dict(self):
def as_dict(self) -> dict[str, Any]:
"""
Returns:
MSONable dictionary representation of PatchedPhaseDiagram
dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram
"""
return {
"@module": type(self).__module__,
Expand Down Expand Up @@ -1618,15 +1624,15 @@ def from_dict(cls, d):
# get_decomp_and_phase_separation_energy(),
# get_phase_separation_energy()

def get_pd_for_entry(self, entry):
def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram:
"""
Get the possible phase diagrams for an entry
Args:
entry (PDEntry/Composition): a PDEntry or Composition like entry
entry (PDEntry | Composition): A PDEntry or Composition-like object
Returns:
Dictionary of {space: PhaseDiagram} that the entry is part of
PhaseDiagram: phase diagram that the entry is part of
"""
if isinstance(entry, Composition):
entry_space = frozenset(entry.elements)
Expand All @@ -1642,7 +1648,7 @@ def get_pd_for_entry(self, entry):

raise ValueError(f"No suitable PhaseDiagrams found for {entry}.")

def get_decomposition(self, comp):
def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]:
"""
See PhaseDiagram
Expand Down

0 comments on commit 4904bbc

Please sign in to comment.