Skip to content

Commit

Permalink
Allow thermo_type string passing for phase diagram method (#742)
Browse files Browse the repository at this point in the history
* Allow string inputs for thermo type in pd method

* Fix type hint

* Linting

* Fix exclude element testts

* Update ele tests

* Exclude nsites

* Revert test change

* Fix set
  • Loading branch information
munrojm committed Feb 10, 2023
1 parent 9dc27dc commit 2ec98bd
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 38 deletions.
8 changes: 3 additions & 5 deletions mp_api/client/mprester.py
Expand Up @@ -169,7 +169,7 @@ def __init__(
session=self.session,
monty_decode=monty_decode,
use_document_model=use_document_model,
headers=self.headers
headers=self.headers,
) # type: BaseRester

self._all_resters.append(rester)
Expand Down Expand Up @@ -710,9 +710,7 @@ def get_ion_reference_data(self) -> List[Dict]:
compounds and aqueous species, Wiley, New York (1978)'}}
"""
return self.contribs.query_contributions(
query={"project": "ion_ref_data"},
fields=["identifier", "formula", "data"],
paginate=True
query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True
).get("data")

def get_ion_reference_data_for_chemsys(self, chemsys: Union[str, List]) -> List[Dict]:
Expand Down Expand Up @@ -926,7 +924,7 @@ def get_entries_in_chemsys(
correspond to proper function inputs to `MPRester.thermo.search`. For instance,
if you are only interested in entries on the convex hull, you could pass
{"energy_above_hull": (0.0, 0.0)} or {"is_stable": True}, or if you are only interested
in entry data
in entry data
Returns:
List of ComputedStructureEntries.
"""
Expand Down
31 changes: 13 additions & 18 deletions mp_api/client/routes/thermo.py
Expand Up @@ -24,8 +24,7 @@ def search_thermo_docs(self, *args, **kwargs): # pragma: no cover
"""

warnings.warn(
"MPRester.thermo.search_thermo_docs is deprecated. "
"Please use MPRester.thermo.search instead.",
"MPRester.thermo.search_thermo_docs is deprecated. " "Please use MPRester.thermo.search instead.",
DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -109,25 +108,19 @@ def search(
t_types = {t if isinstance(t, str) else t.value for t in thermo_types}
valid_types = {*map(str, ThermoType.__members__.values())}
if invalid_types := t_types - valid_types:
raise ValueError(
f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}"
)
raise ValueError(f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}")
query_params.update({"thermo_types": ",".join(t_types)})

if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
query_params.update(
{"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
)
query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]})

if is_stable is not None:
query_params.update({"is_stable": is_stable})

if sort_fields:
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})

name_dict = {
"total_energy": "energy_per_atom",
Expand All @@ -146,11 +139,7 @@ def search(
}
)

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}
query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}

return super()._search(
num_chunks=num_chunks,
Expand All @@ -161,7 +150,7 @@ def search(
)

def get_phase_diagram_from_chemsys(
self, chemsys: str, thermo_type: ThermoType = ThermoType.GGA_GGA_U
self, chemsys: str, thermo_type: Union[ThermoType, str] = ThermoType.GGA_GGA_U
) -> PhaseDiagram:
"""
Get a pre-computed phase diagram for a given chemsys.
Expand All @@ -173,8 +162,14 @@ def get_phase_diagram_from_chemsys(
Returns:
phase_diagram (PhaseDiagram): Pymatgen phase diagram object.
"""

t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value
valid_types = {*map(str, ThermoType.__members__.values())}
if invalid_types := {t_type} - valid_types:
raise ValueError(f"Invalid thermo type(s) passed: {invalid_types}, valid types are: {valid_types}")

sorted_chemsys = "-".join(sorted(chemsys.split("-")))
phase_diagram_id = f"{sorted_chemsys}_{thermo_type.value}"
phase_diagram_id = f"{sorted_chemsys}_{t_type}"
response = self._query_resource(
fields=["phase_diagram"],
suburl=f"phase_diagram/{phase_diagram_id}",
Expand Down
11 changes: 3 additions & 8 deletions tests/test_materials.py
Expand Up @@ -19,6 +19,7 @@ def rester():
"num_chunks",
"all_fields",
"fields",
"exclude_elements", # temp until timeout update
]

sub_doc_fields = [] # type: list
Expand All @@ -39,17 +40,14 @@ def rester():
"formula": "Si",
"chemsys": "Si-O",
"elements": ["Si", "O"],
"exclude_elements": ["Si"],
"task_ids": ["mp-149"],
"crystal_system": CrystalSystem.cubic,
"spacegroup_number": 38,
"spacegroup_symbol": "Amm2",
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down Expand Up @@ -98,7 +96,4 @@ def test_client(rester):
if sub_field in doc:
doc = doc[sub_field]

assert (
doc[project_field if project_field is not None else param]
is not None
)
assert doc[project_field if project_field is not None else param] is not None
10 changes: 3 additions & 7 deletions tests/test_summary.py
Expand Up @@ -14,6 +14,7 @@
"all_fields",
"fields",
"equilibrium_reaction_energy", # temp until data update
"exclude_elements", # temp until data update
]

alt_name_dict = {
Expand Down Expand Up @@ -42,7 +43,6 @@
"formula": "SiO2",
"chemsys": "Si-O",
"elements": ["Si", "O"],
"exclude_elements": ["Si"],
"possible_species": ["O2-"],
"crystal_system": CrystalSystem.cubic,
"spacegroup_number": 38,
Expand All @@ -54,9 +54,7 @@
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
def test_client():

search_method = SummaryRester().search
Expand Down Expand Up @@ -107,6 +105,4 @@ def test_client():
else:
raise ValueError("No documents returned")

assert (
doc[project_field if project_field is not None else param] is not None
)
assert doc[project_field if project_field is not None else param] is not None

0 comments on commit 2ec98bd

Please sign in to comment.