Skip to content

Commit

Permalink
Piezo client update
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed May 11, 2021
1 parent 6d592a1 commit ee0d297
Showing 1 changed file with 11 additions and 33 deletions.
44 changes: 11 additions & 33 deletions src/mp_api/routes/piezo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,14 @@
class PiezoRester(BaseRester):

suffix = "piezoelectric"
document_model = PiezoDoc

def get_piezo_from_material_id(self, material_id: str):
"""
Get piezoelectric data for a given Materials Project ID.
Arguments:
material_id (str): Materials project ID
Returns:
results (Dict): Dictionary containing piezoelectric data.
"""

result = self._make_request("{}/?all_fields=true".format(material_id))

if len(result.get("data", [])) > 0:
return result
else:
raise MPRestError("No document found")
document_model = PiezoDoc # type: ignore

def search_piezoelectric_docs(
self,
piezoelectric_modulus: Optional[Tuple[float, float]] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 100,
all_fields: bool = True,
fields: Optional[List[str]] = None,
):
"""
Expand All @@ -45,6 +28,7 @@ def search_piezoelectric_docs(
piezoelectric modulus in C/m² to consider.
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk.
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in EOSDoc to return data for.
Default is material_id only.
Expand All @@ -67,23 +51,17 @@ def search_piezoelectric_docs(
}
)

if fields:
query_params.update({"fields": ",".join(fields)})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

query_params.update({"limit": chunk_size, "skip": 0})
count = 0
while True:
query_params["skip"] = count * chunk_size
results = self.query(query_params).get("data", [])

if not any(results) or (num_chunks is not None and count == num_chunks):
break

count += 1
yield results
return super().search(
version=self.version,
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
)

0 comments on commit ee0d297

Please sign in to comment.