Skip to content

Commit

Permalink
Enhance usability (#506)
Browse files Browse the repository at this point in the history
* Add parallel requests to initial page

* Factor our multithread handling

* Fix MPRester import

* Fix pagination with doc num tracking

* Bug fixes pagination and enhance use of threading

* Fix no meta in initial page result

* Ensure initial lone critieria has limit

* Fix synthesis client search

* Ensure limit in robocrys query

* Remove query and sort input params

* Privatize search method

* Rename rester search methods

* Update rester tests

* Remove test prints

* Sort rest of input params in search methods

* Query method error

* Dyanmic construction of returned model

* Progress bar mute

* Fix entry comparison

* Remove endpoint field and add superclass to repr

* Remove endpoint from new fields

* Fix new dyanamic model repr

* Ensure multiple synthesis keywords search properly

* Add num_elements to all search methods

* Incorporate formual list queries in search methods

* Thermo formula and working ion to list

* Rename and deprecate get_materials methods

* Add deprecated docstrings

* Add to query not implemented message

* Linting

* Fix summary test

* Update some tests

* Electrode client bug fix

* Fix charge density test

* Temporarily revert document post-processing

* Fix comment

* Comment linting

* Properly deprecate old search methods

* Linting

* Fix coverage

* Oxidation state rester fix
  • Loading branch information
munrojm committed Jun 3, 2022
1 parent 865f33f commit 250f51c
Show file tree
Hide file tree
Showing 52 changed files with 1,466 additions and 889 deletions.
277 changes: 57 additions & 220 deletions src/mp_api/client.py

Large diffs are not rendered by default.

62 changes: 57 additions & 5 deletions src/mp_api/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from matplotlib import use
from monty.json import MontyDecoder
from mp_api.core.settings import MAPIClientSettings
from pydantic import BaseModel, create_model
from mp_api.core.utils import validate_ids
from pydantic import BaseModel
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from tqdm.auto import tqdm
Expand Down Expand Up @@ -567,14 +567,18 @@ def _multi_thread(
with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = set({})
for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):
futures = set()

for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
):

future = executor.submit(
self._submit_request_and_process,
use_document_model=use_document_model,
**params,
)

setattr(future, "crit_ind", params_ind)
futures.add(future)
params_ind += 1
Expand All @@ -584,18 +588,22 @@ def _multi_thread(
finished, futures = wait(futures, return_when=FIRST_COMPLETED)

for future in finished:

data, subtotal = future.result()

if progress_bar is not None:
progress_bar.update(len(data["data"]))
return_data.append((data, subtotal, future.crit_ind)) # type: ignore

# Populate more futures to replace finished
for params in itertools.islice(params_gen, len(finished)):

new_future = executor.submit(
self._submit_request_and_process,
use_document_model=use_document_model,
**params,
)

setattr(new_future, "crit_ind", params_ind)
futures.add(new_future)
params_ind += 1
Expand Down Expand Up @@ -630,7 +638,12 @@ def _submit_request_and_process(
# other sub-urls may use different document models
# the client does not handle this in a particularly smart way currently
if self.document_model and use_document_model:
data["data"] = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore
raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

# Temporarily removed until user-testing completed
# data["data"] = self._generate_returned_model(raw_doc_list)

data["data"] = raw_doc_list

meta_total_doc_num = data.get("meta", {}).get("total_doc", 1)

Expand All @@ -654,6 +667,45 @@ def _submit_request_and_process(
f"on URL {response.url} with message:\n{message}"
)

def _generate_returned_model(self, data):

new_data = []

for doc in data:
set_data = {
field: value
for field, value in doc
if field in doc.dict(exclude_unset=True)
}
unset_fields = [field for field in doc.__fields__ if field not in set_data]

data_model = create_model(
"MPDataEntry",
fields_not_requested=unset_fields,
__base__=self.document_model,
)

data_model.__fields__ = {
**{
name: description
for name, description in data_model.__fields__.items()
if name in set_data
},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ", ".join(
f"{n}={getattr(self, n)!r}" for n in data_model.__fields__
)
return f"{self.__class__.__name__}<{self.__class__.__base__.__name__}>({extra})"

data_model.__repr__ = new_repr

new_data.append(data_model(**set_data))

return new_data

def _query_resource_data(
self,
criteria: Optional[Dict] = None,
Expand Down Expand Up @@ -752,7 +804,7 @@ def get_data_by_id(
else:
return results[0]

def search(
def _search(
self,
num_chunks: Optional[int] = None,
chunk_size: int = 1000,
Expand Down
35 changes: 25 additions & 10 deletions src/mp_api/routes/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,35 @@
from mp_api.core.client import BaseRester
from emmet.core.bonds import BondingDoc

import warnings


class BondsRester(BaseRester[BondingDoc]):

suffix = "bonds"
document_model = BondingDoc # type: ignore
primary_key = "material_id"

def search_bonds_docs(
def search_bonds_docs(self, *args, **kwargs): # pragma: no cover
"""
Deprecated
"""

warnings.warn(
"MPRester.bonds.search_bonds_docs is deprecated. Please use MPRester.bonds.search instead.",
DeprecationWarning,
stacklevel=2,
)

return self.search(*args, **kwargs)

def search(
self,
max_bond_length: Optional[Tuple[float, float]] = None,
min_bond_length: Optional[Tuple[float, float]] = None,
mean_bond_length: Optional[Tuple[float, float]] = None,
coordination_envs: Optional[List[str]] = None,
coordination_envs_anonymous: Optional[List[str]] = None,
max_bond_length: Optional[Tuple[float, float]] = None,
mean_bond_length: Optional[Tuple[float, float]] = None,
min_bond_length: Optional[Tuple[float, float]] = None,
sort_fields: Optional[List[str]] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 1000,
Expand All @@ -28,15 +43,15 @@ def search_bonds_docs(
Query bonding docs using a variety of search criteria.
Arguments:
coordination_envs (List[str]): List of coordination environments to consider (e.g. ['Mo-S(6)', 'S-Mo(3)']).
coordination_envs_anonymous (List[str]): List of anonymous coordination environments to consider
(e.g. ['A-B(6)', 'A-B(3)']).
max_bond_length (Tuple[float,float]): Minimum and maximum value for the maximum bond length
in the structure to consider.
min_bond_length (Tuple[float,float]): Minimum and maximum value for the minimum bond length
in the structure to consider.
mean_bond_length (Tuple[float,float]): Minimum and maximum value for the mean bond length
in the structure to consider.
coordination_envs (List[str]): List of coordination environments to consider (e.g. ['Mo-S(6)', 'S-Mo(3)']).
coordination_envs_anonymous (List[str]): List of anonymous coordination environments to consider
(e.g. ['A-B(6)', 'A-B(3)']).
min_bond_length (Tuple[float,float]): Minimum and maximum value for the minimum bond length
in the structure to consider.
sort_fields (List[str]): Fields used to sort results. Prefixing with '-' will sort in descending order.
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk.
Expand Down Expand Up @@ -93,7 +108,7 @@ def search_bonds_docs(
if query_params[entry] is not None
}

return super().search(
return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/routes/charge_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def search( # type: ignore
A list of ChgcarDataDoc that contain task_id references.
"""

return super().search(
return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=False,
Expand Down
19 changes: 17 additions & 2 deletions src/mp_api/routes/dielectric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,29 @@
from emmet.core.polar import DielectricDoc
from mp_api.core.client import BaseRester

import warnings


class DielectricRester(BaseRester[DielectricDoc]):

suffix = "dielectric"
document_model = DielectricDoc # type: ignore
primary_key = "material_id"

def search_dielectric_docs(
def search_dielectric_docs(self, *args, **kwargs): # pragma: no cover
"""
Deprecated
"""

warnings.warn(
"MPRester.dielectric.search_dielectric_docs is deprecated. Please use MPRester.dielectric.search instead.",
DeprecationWarning,
stacklevel=2,
)

return self.search(*args, **kwargs)

def search(
self,
e_total: Optional[Tuple[float, float]] = None,
e_ionic: Optional[Tuple[float, float]] = None,
Expand Down Expand Up @@ -72,7 +87,7 @@ def search_dielectric_docs(
if query_params[entry] is not None
}

return super().search(
return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
Expand Down
7 changes: 7 additions & 0 deletions src/mp_api/routes/doi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,10 @@ class DOIRester(BaseRester[DOIDoc]):
suffix = "doi"
document_model = DOIDoc # type: ignore
primary_key = "task_id"

def search(*args, **kwargs): # pragma: no cover
raise NotImplementedError(
"""
The DOIRester.search method does not exist as no search endpoint is present. Use get_data_by_id instead.
"""
)
43 changes: 29 additions & 14 deletions src/mp_api/routes/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,37 @@
from emmet.core.elasticity import ElasticityDoc
from mp_api.core.client import BaseRester

import warnings


class ElasticityRester(BaseRester[ElasticityDoc]):

suffix = "elasticity"
document_model = ElasticityDoc # type: ignore
primary_key = "task_id"

def search_elasticity_docs(
def search_elasticity_docs(self, *args, **kwargs): # pragma: no cover
"""
Deprecated
"""

warnings.warn(
"MPRester.elasticity.search_elasticity_docs is deprecated. Please use MPRester.elasticity.search instead.",
DeprecationWarning,
stacklevel=2,
)

return self.search(*args, **kwargs)

def search(
self,
k_voigt: Optional[Tuple[float, float]] = None,
k_reuss: Optional[Tuple[float, float]] = None,
k_vrh: Optional[Tuple[float, float]] = None,
elastic_anisotropy: Optional[Tuple[float, float]] = None,
g_voigt: Optional[Tuple[float, float]] = None,
g_reuss: Optional[Tuple[float, float]] = None,
g_vrh: Optional[Tuple[float, float]] = None,
elastic_anisotropy: Optional[Tuple[float, float]] = None,
k_voigt: Optional[Tuple[float, float]] = None,
k_reuss: Optional[Tuple[float, float]] = None,
k_vrh: Optional[Tuple[float, float]] = None,
poisson_ratio: Optional[Tuple[float, float]] = None,
sort_fields: Optional[List[str]] = None,
num_chunks: Optional[int] = None,
Expand All @@ -31,20 +46,20 @@ def search_elasticity_docs(
Query elasticity docs using a variety of search criteria.
Arguments:
k_voigt (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt average of the bulk modulus.
k_reuss (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Reuss average of the bulk modulus.
k_vrh (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt-Reuss-Hill average of the bulk modulus.
elastic_anisotropy (Tuple[float,float]): Minimum and maximum value to consider for
the elastic anisotropy.
g_voigt (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt average of the shear modulus.
g_reuss (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Reuss average of the shear modulus.
g_vrh (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt-Reuss-Hill average of the shear modulus.
elastic_anisotropy (Tuple[float,float]): Minimum and maximum value to consider for
the elastic anisotropy.
k_voigt (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt average of the bulk modulus.
k_reuss (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Reuss average of the bulk modulus.
k_vrh (Tuple[float,float]): Minimum and maximum value in GPa to consider for
the Voigt-Reuss-Hill average of the bulk modulus.
poisson_ratio (Tuple[float,float]): Minimum and maximum value to consider for
Poisson's ratio.
sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order.
Expand Down Expand Up @@ -102,7 +117,7 @@ def search_elasticity_docs(
if query_params[entry] is not None
}

return super().search(
return super()._search(
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
Expand Down
Loading

0 comments on commit 250f51c

Please sign in to comment.