From c58eec68f9770ce58da61b47fc6a5c092977282f Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Tue, 11 May 2021 13:44:23 -0700 Subject: [PATCH] Molecules client update --- src/mp_api/routes/molecules/client.py | 55 +++++++------------------- src/mp_api/routes/molecules/client.pyi | 13 ++++++ 2 files changed, 27 insertions(+), 41 deletions(-) create mode 100644 src/mp_api/routes/molecules/client.pyi diff --git a/src/mp_api/routes/molecules/client.py b/src/mp_api/routes/molecules/client.py index d405f6ca..29ae9189 100644 --- a/src/mp_api/routes/molecules/client.py +++ b/src/mp_api/routes/molecules/client.py @@ -4,7 +4,7 @@ from pymatgen.core.periodic_table import Element -from mp_api.core.client import BaseRester, MPRestError +from mp_api.core.client import BaseRester from mp_api.routes.molecules.models import MoleculesDoc @@ -13,24 +13,6 @@ class MoleculesRester(BaseRester): suffix = "molecules" document_model = MoleculesDoc # type: ignore - def get_molecule_from_molecule_id(self, molecule_id: str): - """ - Get molecule data for a given Materials Project molecule ID. - - Arguments: - molecule_id (str): Materials project molecule ID - - Returns: - results (Dict): Dictionary containing molecule data. - """ - - result = self._make_request("{}/?all_fields=true".format(molecule_id)) - - if len(result.get("data", [])) > 0: - return result - else: - raise MPRestError("No document found") - def search_molecules_docs( self, elements: Optional[List[Element]] = None, @@ -41,7 +23,8 @@ def search_molecules_docs( pointgroup: Optional[str] = None, smiles: Optional[str] = None, num_chunks: Optional[int] = None, - chunk_size: int = 100, + chunk_size: int = 1000, + all_fields: bool = True, fields: Optional[List[str]] = None, ): """ @@ -57,20 +40,16 @@ def search_molecules_docs( smiles (str): The simplified molecular input line-entry system (SMILES) representation of the molecule. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in SubstratesDoc to return data for. Default is the film_id and substrate_id only. - Yields: - ([dict]) List of dictionaries containing data for entries defined in 'fields'. - Defaults to Materials Project molecule IDs. + Returns: + ([MoleculesDoc]) List of molecule documents """ query_params = defaultdict(dict) # type: dict - if chunk_size <= 0 or chunk_size > 100: - warnings.warn("Improper chunk size given. Setting value to 100.") - chunk_size = 100 - if elements: query_params.update({"elements": ",".join([str(ele) for ele in elements])}) @@ -94,23 +73,17 @@ def search_molecules_docs( if charge: query_params.update({"charge_min": charge[0], "charge_max": charge[1]}) - if fields: - query_params.update({"fields": ",".join(fields)}) - query_params = { entry: query_params[entry] for entry in query_params if query_params[entry] is not None } - query_params.update({"limit": chunk_size, "skip": 0}) - count = 0 - while True: - query_params["skip"] = count * chunk_size - results = self.query(query_params).get("data", []) - - if not any(results) or (num_chunks is not None and count == num_chunks): - break - - count += 1 - yield results + return super().search( + version=self.version, + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params + ) diff --git a/src/mp_api/routes/molecules/client.pyi b/src/mp_api/routes/molecules/client.pyi new file mode 100644 index 00000000..da5e5106 --- /dev/null +++ b/src/mp_api/routes/molecules/client.pyi @@ -0,0 +1,13 @@ +from typing import List, Optional +from mp_api.routes.molecules.models import MoleculesDoc + + +class MoleculeRester: + + def get_document_by_id( + self, + document_id: str, + fields: Optional[List[str]] = None, + monty_decode: bool = True, + version: Optional[str] = None, + ) -> MoleculesDoc: ...