Skip to content

Commit

Permalink
Update chemenv documetation and suggestions (#771)
Browse files Browse the repository at this point in the history
* Update documetation and suggestions

* Add chemenv fields and proper typing

* Add None check to chemenv var

* Fix spelling error in doc

* Add chemenv rester test

* Linting

* Chemenv var fix

* Fix tests

* Linting

---------

Co-authored-by: Jason Munro <jmunro@lbl.gov>
Co-authored-by: Jason Munro <jason.munro@gmail.com>
  • Loading branch information
3 people committed May 6, 2023
1 parent 68d54df commit aa93193
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 23 deletions.
84 changes: 61 additions & 23 deletions mp_api/client/routes/chemenv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from collections import defaultdict
from typing import List, Optional, Tuple, Union

from emmet.core.chemenv import ChemEnvDoc
from emmet.core.chemenv import (
COORDINATION_GEOMETRIES,
COORDINATION_GEOMETRIES_IUCR,
COORDINATION_GEOMETRIES_IUPAC,
COORDINATION_GEOMETRIES_NAMES,
ChemEnvDoc,
)

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids
Expand All @@ -15,9 +21,21 @@ class ChemenvRester(BaseRester[ChemEnvDoc]):
def search(
self,
material_ids: Optional[Union[str, List[str]]] = None,
chemenv_iucr: Optional[Union[str, List[str]]] = None,
chemenv_iupac: Optional[Union[str, List[str]]] = None,
chemenv_name: Optional[Union[str, List[str]]] = None,
chemenv_iucr: Optional[
Union[COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]]
] = None,
chemenv_iupac: Optional[
Union[COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]]
] = None,
chemenv_name: Optional[
Union[COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]]
] = None,
chemenv_symbol: Optional[
Union[COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]]
] = None,
species: Optional[Union[str, List[str]]] = None,
elements: Optional[Union[str, List[str]]] = None,
exclude_elements: Optional[List[str]] = None,
csm: Optional[Tuple[float, float]] = None,
density: Optional[Tuple[float, float]] = None,
num_elements: Optional[Tuple[int, int]] = None,
Expand All @@ -28,15 +46,23 @@ def search(
chunk_size: int = 1000,
all_fields: bool = True,
fields: Optional[List[str]] = None,
) -> List[ChemEnvDoc]:
):
"""Query for chemical environment data.
Arguments:
material_ids (str, List[str]): Search forchemical environment associated with the specified Material IDs.
chemenv_iucr (str, List[str]): Unique cationic species in IUCR format.
chemenv_iupac (str, List[str]): Unique cationic species in IUPAC format.
chemenv_iupac (str, List[str]): Coordination environment descriptions for unique cationic species.
density (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider.
chemenv_iucr (COORDINATION_GEOMETRIES_IUCR, List[COORDINATION_GEOMETRIES_IUCR]): Unique cationic species in
IUCR format, e.g. "[3n]".
chemenv_iupac (COORDINATION_GEOMETRIES_IUPAC, List[COORDINATION_GEOMETRIES_IUPAC]): Unique cationic species
in IUPAC format, e.g., "T-4".
chemenv_name (COORDINATION_GEOMETRIES_NAMES, List[COORDINATION_GEOMETRIES_NAMES]): Coordination environment
descriptions in text form for unique cationic species, e.g. "Tetrahedron".
chemenv_symbol (COORDINATION_GEOMETRIES, List[COORDINATION_GEOMETRIES]): Coordination environment
descriptions as used in ChemEnv package for unique cationic species, e.g. "T:4".
species (str, List[str]): Cationic species in the crystal structure, e.g. "Ti4+".
elements (str, List[str]): Element names in the crystal structure, e.g., "Ti".
exclude_elements (List[str]): A list of elements to exclude.
csm (Tuple[float,float]): Minimum and maximum value of continuous symmetry measure to consider.
density (Tuple[float,float]): Minimum and maximum density to consider.
num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider.
num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider.
Expand Down Expand Up @@ -66,6 +92,12 @@ def search(
{"nsites_min": num_sites[0], "nsites_max": num_sites[1]}
)

if elements:
query_params.update({"elements": ",".join(elements)})

if exclude_elements:
query_params.update({"exclude_elements": ",".join(exclude_elements)})

if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
Expand All @@ -79,23 +111,29 @@ def search(

query_params.update({"material_ids": ",".join(validate_ids(material_ids))})

if chemenv_iucr:
if isinstance(chemenv_iucr, str):
chemenv_iucr = [chemenv_iucr]

query_params.update({"chemenv_iucr": ",".join(chemenv_iucr)})
chemenv_literals = {
"chemenv_iucr": (chemenv_iucr, COORDINATION_GEOMETRIES_IUCR),
"chemenv_iupac": (chemenv_iupac, COORDINATION_GEOMETRIES_IUPAC),
"chemenv_name": (chemenv_name, COORDINATION_GEOMETRIES_NAMES),
"chemenv_symbol": (chemenv_symbol, COORDINATION_GEOMETRIES),
}

if chemenv_iupac:
if isinstance(chemenv_iupac, str):
chemenv_iupac = [chemenv_iupac]
for chemenv_var_name, (chemenv_var, literals) in chemenv_literals.items():
if chemenv_var:
t_types = {t if isinstance(t, str) else t.value for t in chemenv_var}
valid_types = {*map(str, literals.__args__)}
if invalid_types := t_types - valid_types:
raise ValueError(
f"Invalid type(s) passed for {chemenv_var_name}: {invalid_types}, valid types are: {valid_types}"
)

query_params.update({"chemenv_iupac": ",".join(chemenv_iupac)})
query_params.update({chemenv_var_name: ",".join(t_types)})

if chemenv_name:
if isinstance(chemenv_name, str):
chemenv_name = [chemenv_name]
if species:
if isinstance(species, str):
species = [species]

query_params.update({"chemenv_name": ",".join(chemenv_name)})
query_params.update({"species": ",".join(species)})

if sort_fields:
query_params.update(
Expand All @@ -113,5 +151,5 @@ def search(
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
**query_params,
)
98 changes: 98 additions & 0 deletions tests/test_chemenv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import typing

import pytest

from mp_api.client.routes.chemenv import ChemenvRester


@pytest.fixture
def rester():
rester = ChemenvRester()
yield rester
rester.session.close()


excluded_params = [
"sort_fields",
"chunk_size",
"num_chunks",
"all_fields",
"fields",
"volume",
]

sub_doc_fields = [] # type: list

alt_name_dict = {
"material_ids": "material_id",
"exclude_elements": "material_id",
"num_elements": "nelements",
"num_sites": "nsites",
} # type: dict

custom_field_tests = {
"material_ids": ["mp-22526"],
"elements": ["Si", "O"],
"exclude_elements": ["Si", "O"],
"chemenv_symbol": ["S:1"],
"chemenv_iupac": ["IC-12"],
"chemenv_iucr": ["[2l]"],
"chemenv_name": ["Octahedron"],
"species": ["Cu2+"],
} # type: dict


@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

if search_method is not None:
# Get list of parameters
param_tuples = list(typing.get_type_hints(search_method).items())

# Query API for each numeric and boolean parameter and check if returned
for entry in param_tuples:
param = entry[0]
if param not in excluded_params:
param_type = entry[1].__args__[0]
q = None

if param_type == typing.Tuple[int, int]:
project_field = alt_name_dict.get(param, None)
q = {
param: (-100, 100),
"chunk_size": 1,
"num_chunks": 1,
}
elif param_type == typing.Tuple[float, float]:
project_field = alt_name_dict.get(param, None)
q = {
param: (-1.12, 1.12),
"chunk_size": 1,
"num_chunks": 1,
}
elif param_type is bool:
project_field = alt_name_dict.get(param, None)
q = {
param: False,
"chunk_size": 1,
"num_chunks": 1,
}
elif param in custom_field_tests:
project_field = alt_name_dict.get(param, None)
q = {
param: custom_field_tests[param],
"chunk_size": 1,
"num_chunks": 1,
}
doc = search_method(**q)[0].dict()

for sub_field in sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]

assert (
doc[project_field if project_field is not None else param]
is not None
)

0 comments on commit aa93193

Please sign in to comment.