Skip to content

Commit

Permalink
Add back additional_criteria as an input to get_entries and `get_…
Browse files Browse the repository at this point in the history
…entries_in_chemsys` (#693)

* Add back additional_criteria to get_entries

* Change task test input

* Revert task test

* Update generic rester tests

* Update gitignore and tests

* Update imports and default thread number for multithreading

* Remove formula in client test for tasks

* Allow top level entry related methods to work without de-serialization enabled

* Linting

* Fix energy key

* Reduce parallel test runs
  • Loading branch information
munrojm committed Oct 12, 2022
1 parent 2292fc1 commit 72eafaa
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
test:
strategy:
max-parallel: 6
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.8, 3.9]
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ ENV/
.project
.pydevproject

# fleet
.fleet

*~

.idea
Expand Down
10 changes: 9 additions & 1 deletion mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from pydantic import BaseSettings, Field
from mp_api.client import __file__ as root_dir
from multiprocessing import cpu_count
from typing import List
import os

# Fallback value, kept when the platform cannot report its CPU count.
# CPU_COUNT is used below as the default for NUM_PARALLEL_REQUESTS.
CPU_COUNT = 8

try:
CPU_COUNT = cpu_count()
except NotImplementedError:
# multiprocessing.cpu_count() may raise NotImplementedError; keep the fallback of 8.
pass


class MAPIClientSettings(BaseSettings):
"""
Expand Down Expand Up @@ -41,7 +49,7 @@ class MAPIClientSettings(BaseSettings):
)

NUM_PARALLEL_REQUESTS: int = Field(
8, description="Number of parallel requests to send.",
CPU_COUNT, description="Number of parallel requests to send.",
)

MAX_RETRIES: int = Field(3, description="Maximum number of retries for requests.")
Expand Down
155 changes: 95 additions & 60 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import itertools
from multiprocessing.sharedctypes import Value
import warnings
from functools import lru_cache
from os import environ
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union

from emmet.core.charge_density import ChgcarDataDoc
from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
from emmet.core.settings import EmmetSettings
from emmet.core.summary import HasProps
from emmet.core.symmetry import CrystalSystem
from emmet.core.vasp.calc_types import CalcType
from pymatgen.analysis.magnetism import Ordering
from pymatgen.analysis.phase_diagram import PhaseDiagram
from pymatgen.analysis.pourbaix_diagram import IonEntry
from pymatgen.core import Composition, Element, Structure
from pymatgen.core import Element, Structure
from pymatgen.core.ion import Ion
from pymatgen.entries.computed_entries import ComputedEntry, ComputedStructureEntry
from pymatgen.io.vasp import Chgcar
Expand Down Expand Up @@ -132,15 +128,20 @@ def __init__(
self.session = BaseRester._create_session(
api_key=api_key, include_user_agent=include_user_agent
)
self.use_document_model = use_document_model
self.monty_decode = monty_decode

try:
from mpcontribs.client import Client

self.contribs = Client(api_key)
except ImportError:
self.contribs = None
warnings.warn("mpcontribs-client not installed. "
"Install the package to query MPContribs data, or construct pourbaix diagrams: "
"'pip install mpcontribs-client'")
warnings.warn(
"mpcontribs-client not installed. "
"Install the package to query MPContribs data, or construct pourbaix diagrams: "
"'pip install mpcontribs-client'"
)
except Exception as error:
self.contribs = None
warnings.warn(f"Problem loading MPContribs client: {error}")
Expand Down Expand Up @@ -186,15 +187,19 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def __getattr__(self, attr):
if attr == "alloys":
raise MPRestError("Alloy addon package not installed. "
"To query alloy data first install with: 'pip install pymatgen-analysis-alloys'")
raise MPRestError(
"Alloy addon package not installed. "
"To query alloy data first install with: 'pip install pymatgen-analysis-alloys'"
)
elif attr == "charge_density":
raise MPRestError("boto3 not installed. "
"To query charge density data first install with: 'pip install boto3'")
raise MPRestError(
"boto3 not installed. "
"To query charge density data first install with: 'pip install boto3'"
)
else:
raise AttributeError(
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)

def get_task_ids_associated_with_material_id(
self, material_id: str, calc_types: Optional[List[CalcType]] = None
Expand Down Expand Up @@ -446,7 +451,8 @@ def get_entries(
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
sort_by_e_above_hull=False,
sort_by_e_above_hull: bool = False,
additional_criteria: dict = None,
) -> List[ComputedStructureEntry]:
"""
Get a list of ComputedEntries or ComputedStructureEntries corresponding
Expand Down Expand Up @@ -476,14 +482,20 @@ def get_entries(
conventional unit cell
sort_by_e_above_hull (bool): Whether to sort the list of entries by
e_above_hull in ascending order.
additional_criteria (dict): Any additional criteria to pass. The keys and values should
correspond to proper function inputs to `MPRester.thermo.search`. For instance,
if you are only interested in entries on the convex hull, you could pass
{"energy_above_hull": (0.0, 0.0)} or {"is_stable": True}.
Returns:
List ComputedStructureEntry objects.
"""

if inc_structure is not None:
warnings.warn("The 'inc_structure' argument is deprecated as structure "
"data is now always included in all returned entry objects.")
warnings.warn(
"The 'inc_structure' argument is deprecated as structure "
"data is now always included in all returned entry objects."
)

if isinstance(chemsys_formula_mpids, str):
chemsys_formula_mpids = [chemsys_formula_mpids]
Expand All @@ -497,6 +509,9 @@ def get_entries(
else:
input_params = {"formula": chemsys_formula_mpids}

if additional_criteria:
input_params = {**input_params, **additional_criteria}

entries = []

fields = ["entries"] if not property_data else ["entries"] + property_data
Expand All @@ -514,22 +529,25 @@ def get_entries(
)

for doc in docs:
for entry in doc.entries.values():
entry_list = doc.entries.values() if self.use_document_model else doc["entries"].values()
for entry in entry_list:
entry_dict = entry.as_dict() if self.monty_decode else entry
if not compatible_only:
entry.correction = 0.0
entry.energy_adjustments = []
entry_dict["correction"] = 0.0
entry_dict["energy_adjustments"] = []

if property_data:
for property in property_data:
entry.data[property] = doc.dict()[property]
entry_dict["data"][property] = doc.dict()[property] if self.use_document_model else doc[
property]

if conventional_unit_cell:

s = SpacegroupAnalyzer(entry.structure).get_conventional_standard_structure()
site_ratio = (len(s) / len(entry.structure))
new_energy = entry.uncorrected_energy * site_ratio
entry_struct = Structure.from_dict(entry_dict["structure"])
s = SpacegroupAnalyzer(entry_struct).get_conventional_standard_structure()
site_ratio = len(s) / len(entry_struct)
new_energy = entry_dict["energy"] * site_ratio

entry_dict = entry.as_dict()
entry_dict["energy"] = new_energy
entry_dict["structure"] = s.as_dict()
entry_dict["correction"] = 0.0
Expand All @@ -540,7 +558,7 @@ def get_entries(
for correction in entry_dict["energy_adjustments"]:
correction["n_atoms"] *= site_ratio

entry = ComputedStructureEntry.from_dict(entry_dict)
entry = ComputedStructureEntry.from_dict(entry_dict) if self.monty_decode else entry_dict

entries.append(entry)

Expand Down Expand Up @@ -575,9 +593,11 @@ def get_pourbaix_entries(
# imports are not top-level due to expense
from pymatgen.analysis.pourbaix_diagram import PourbaixEntry
from pymatgen.entries.compatibility import (
Compatibility, MaterialsProject2020Compatibility,
Compatibility,
MaterialsProject2020Compatibility,
MaterialsProjectAqueousCompatibility,
MaterialsProjectCompatibility)
MaterialsProjectCompatibility,
)
from pymatgen.entries.computed_entries import ComputedEntry

if solid_compat == "MaterialsProjectCompatibility":
Expand Down Expand Up @@ -638,8 +658,7 @@ def get_pourbaix_entries(
# could be removed
if use_gibbs:
# replace the entries with GibbsComputedStructureEntry
from pymatgen.entries.computed_entries import \
GibbsComputedStructureEntry
from pymatgen.entries.computed_entries import GibbsComputedStructureEntry

ion_ref_entries = GibbsComputedStructureEntry.from_entries(
ion_ref_entries, temp=use_gibbs
Expand Down Expand Up @@ -846,11 +865,14 @@ def get_ion_entries(

return ion_entries

def get_entry_by_material_id(self, material_id: str,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,):
def get_entry_by_material_id(
self,
material_id: str,
compatible_only: bool = True,
inc_structure: bool = None,
property_data: List[str] = None,
conventional_unit_cell: bool = False,
):
"""
Get all ComputedEntry objects corresponding to a material_id.
Expand All @@ -877,14 +899,17 @@ def get_entry_by_material_id(self, material_id: str,
Returns:
List of ComputedEntry or ComputedStructureEntry object.
"""
return self.get_entries(material_id,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell)
return self.get_entries(
material_id,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell,
)

def get_entries_in_chemsys(
self, elements: Union[str, List[str]],
self,
elements: Union[str, List[str]],
use_gibbs: Optional[int] = None,
compatible_only: bool = True,
inc_structure: bool = None,
Expand Down Expand Up @@ -924,17 +949,14 @@ def get_entries_in_chemsys(
input parameters in the 'MPRester.thermo.available_fields' list.
conventional_unit_cell (bool): Whether to get the standard
conventional unit cell
additional_criteria (dict): *This is a deprecated argument*. To obtain entry objects
with additional criteria, use the `MPRester.thermo.search` method directly.
additional_criteria (dict): Any additional criteria to pass. The keys and values should
correspond to proper function inputs to `MPRester.thermo.search`. For instance,
if you are only interested in entries on the convex hull, you could pass
{"energy_above_hull": (0.0, 0.0)} or {"is_stable": True}.
Returns:
List of ComputedStructureEntries.
"""

if additional_criteria is not None:
warnings.warn("The 'additional_criteria' argument is deprecated. "
"To obtain entry objects with additional criteria, use "
"the 'MPRester.thermo.search' method directly")

if isinstance(elements, str):
elements = elements.split("-")

Expand All @@ -945,19 +967,29 @@ def get_entries_in_chemsys(

entries = [] # type: List[ComputedEntry]

entries.extend(self.get_entries(all_chemsyses,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell))
entries.extend(
self.get_entries(
all_chemsyses,
compatible_only=compatible_only,
inc_structure=inc_structure,
property_data=property_data,
conventional_unit_cell=conventional_unit_cell,
additional_criteria=additional_criteria,
)
)

if not self.monty_decode:
entries = [ComputedStructureEntry.from_dict(entry) for entry in entries]

if use_gibbs:
# replace the entries with GibbsComputedStructureEntry
from pymatgen.entries.computed_entries import \
GibbsComputedStructureEntry
from pymatgen.entries.computed_entries import GibbsComputedStructureEntry

entries = GibbsComputedStructureEntry.from_entries(entries, temp=use_gibbs)

if not self.monty_decode:
entries = [entry.as_dict() for entry in entries]

return entries

def get_bandstructure_by_material_id(
Expand All @@ -970,7 +1002,7 @@ def get_bandstructure_by_material_id(
Get the band structure pymatgen object associated with a Materials Project ID.
Arguments:
materials_id (str): Materials Project ID for a material
material_id (str): Materials Project ID for a material
path_type (BSPathType): k-point path selection convention
line_mode (bool): Whether to return data for a line-mode calculation
Expand All @@ -986,7 +1018,7 @@ def get_dos_by_material_id(self, material_id: str):
Get the complete density of states pymatgen object associated with a Materials Project ID.
Arguments:
materials_id (str): Materials Project ID for a material
material_id (str): Materials Project ID for a material
Returns:
dos (CompleteDos): CompleteDos object
Expand Down Expand Up @@ -1028,6 +1060,7 @@ def submit_structures(self, structures, public_name, public_email):
Args:
structures: A list of Structure objects
Returns:
?
"""
Expand Down Expand Up @@ -1077,8 +1110,10 @@ def get_charge_density_from_material_id(
"""

if not hasattr(self, "charge_density"):
raise MPRestError("boto3 not installed. "
"To query charge density data install the boto3 package.")
raise MPRestError(
"boto3 not installed. "
"To query charge density data install the boto3 package."
)

# TODO: really we want a recommended task_id for charge densities here
# this could potentially introduce an ambiguity
Expand Down
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
"_user_settings",
"_general_store",
"tasks",
"bonds",
"xas",
"elasticity",
"fermi",
"alloys",
"summary",
] # temp
Expand Down
Loading

0 comments on commit 72eafaa

Please sign in to comment.