From ee0d297706183ed70ba342c150a69288417c51a3 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Tue, 11 May 2021 13:54:07 -0700 Subject: [PATCH] Piezo client update --- src/mp_api/routes/piezo/client.py | 44 ++++++++----------------------- 1 file changed, 11 insertions(+), 33 deletions(-) diff --git a/src/mp_api/routes/piezo/client.py b/src/mp_api/routes/piezo/client.py index 16ce3456..dce4d7b4 100644 --- a/src/mp_api/routes/piezo/client.py +++ b/src/mp_api/routes/piezo/client.py @@ -10,31 +10,14 @@ class PiezoRester(BaseRester): suffix = "piezoelectric" - document_model = PiezoDoc - - def get_piezo_from_material_id(self, material_id: str): - """ - Get piezoelectric data for a given Materials Project ID. - - Arguments: - material_id (str): Materials project ID - - Returns: - results (Dict): Dictionary containing piezoelectric data. - """ - - result = self._make_request("{}/?all_fields=true".format(material_id)) - - if len(result.get("data", [])) > 0: - return result - else: - raise MPRestError("No document found") + document_model = PiezoDoc # type: ignore def search_piezoelectric_docs( self, piezoelectric_modulus: Optional[Tuple[float, float]] = None, num_chunks: Optional[int] = None, chunk_size: int = 100, + all_fields: bool = True, fields: Optional[List[str]] = None, ): """ @@ -45,6 +28,7 @@ def search_piezoelectric_docs( piezoelectric modulus in C/m² to consider. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in EOSDoc to return data for. Default is material_id only. @@ -67,23 +51,17 @@ def search_piezoelectric_docs( } ) - if fields: - query_params.update({"fields": ",".join(fields)}) - query_params = { entry: query_params[entry] for entry in query_params if query_params[entry] is not None } - query_params.update({"limit": chunk_size, "skip": 0}) - count = 0 - while True: - query_params["skip"] = count * chunk_size - results = self.query(query_params).get("data", []) - - if not any(results) or (num_chunks is not None and count == num_chunks): - break - - count += 1 - yield results + return super().search( + version=self.version, + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params + )