Skip to content

Commit

Permalink
ES client docstring update
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Jun 29, 2021
1 parent 53a710c commit 9ff5f69
Showing 1 changed file with 66 additions and 18 deletions.
84 changes: 66 additions & 18 deletions src/mp_api/routes/electronic_structure/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def search_electronic_structure_docs(
query_params = defaultdict(dict) # type: dict

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -72,10 +74,18 @@ def search_electronic_structure_docs(
if ascending is not None:
query_params.update({"ascending": ascending})

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super().search(
num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
)


Expand Down Expand Up @@ -126,7 +136,9 @@ def search_bandstructure_summary(
query_params["path_type"] = path_type.value

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -146,10 +158,18 @@ def search_bandstructure_summary(
if ascending is not None:
query_params.update({"ascending": ascending})

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super().search(
num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
)

def get_bandstructure_from_calculation_id(self, task_id: str):
Expand All @@ -163,7 +183,9 @@ def get_bandstructure_from_calculation_id(self, task_id: str):
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""

result = self._query_resource(criteria={"task_id": task_id}, suburl="object", use_document_model=False)
result = self._query_resource(
criteria={"task_id": task_id}, suburl="object", use_document_model=False
)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -184,14 +206,22 @@ def get_bandstructure_from_material_id(
bandstructure (BandStructureSymmLine): BandStructureSymmLine object
"""

es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)

bs_data = es_rester.get_document_by_id(document_id=material_id, fields=["bandstructure"]).bandstructure.dict()
bs_data = es_rester.get_document_by_id(
document_id=material_id, fields=["bandstructure"]
).bandstructure.dict()

if bs_data[path_type.value]:
bs_calc_id = bs_data[path_type.value]["task_id"]
else:
raise MPRestError("No {} band structure data found for {}".format(path_type.value, material_id))
raise MPRestError(
"No {} band structure data found for {}".format(
path_type.value, material_id
)
)

bs_obj = self.get_bandstructure_from_calculation_id(bs_calc_id)

Expand Down Expand Up @@ -257,7 +287,9 @@ def search_dos_summary(
query_params["orbital"] = orbital.value

if band_gap:
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -271,10 +303,18 @@ def search_dos_summary(
if is_metal is not None:
query_params.update({"is_metal": is_metal})

query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}
query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

return super().search(
num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, **query_params
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
)

def get_dos_from_calculation_id(self, task_id: str):
Expand All @@ -288,7 +328,9 @@ def get_dos_from_calculation_id(self, task_id: str):
bandstructure (CompleteDos): CompleteDos object
"""

result = self._query_resource(criteria={"task_id": task_id}, suburl="object", use_document_model=False)
result = self._query_resource(
criteria={"task_id": task_id}, suburl="object", use_document_model=False
)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -306,18 +348,24 @@ def get_dos_from_material_id(self, material_id: str):
dos (CompleteDos): CompleteDos object
"""

es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)

dos_data = es_rester.get_document_by_id(document_id=material_id, fields=["dos"]).dict()
dos_data = es_rester.get_document_by_id(
document_id=material_id, fields=["dos"]
).dict()

if dos_data["dos"]:
dos_calc_id = dos_data["dos"]["total"]["1"]["task_id"]
else:
raise MPRestError("No density of states data found for {}".format(material_id))
raise MPRestError(
"No density of states data found for {}".format(material_id)
)

dos_obj = self.get_dos_from_calculation_id(dos_calc_id)

if dos_obj:
return dos_obj[0]["data"]
else:
raise MPRestError("No band structure object found.")
raise MPRestError("No density of states object found.")

0 comments on commit 9ff5f69

Please sign in to comment.