Skip to content

Commit

Permalink
Fix de-serlization of dos object
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Jun 14, 2023
1 parent bf390c7 commit 8f0358a
Showing 1 changed file with 23 additions and 74 deletions.
97 changes: 23 additions & 74 deletions mp_api/client/routes/materials/electronic_structure.py
Expand Up @@ -109,9 +109,7 @@ def search(
query_params.update({"exclude_elements": ",".join(exclude_elements)})

if band_gap:
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -122,9 +120,7 @@ def search(
if num_elements:
if isinstance(num_elements, int):
num_elements = (num_elements, num_elements)
query_params.update(
{"nelements_min": num_elements[0], "nelements_max": num_elements[1]}
)
query_params.update({"nelements_min": num_elements[0], "nelements_max": num_elements[1]})

if is_gap_direct is not None:
query_params.update({"is_gap_direct": is_gap_direct})
Expand All @@ -133,15 +129,9 @@ def search(
query_params.update({"is_metal": is_metal})

if sort_fields:
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}
query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}

return super()._search(
num_chunks=num_chunks,
Expand Down Expand Up @@ -205,9 +195,7 @@ def search(
query_params["path_type"] = path_type.value

if band_gap:
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -222,15 +210,9 @@ def search(
query_params.update({"is_metal": is_metal})

if sort_fields:
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}
query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}

return super()._search(
num_chunks=num_chunks,
Expand All @@ -250,9 +232,7 @@ def get_bandstructure_from_task_id(self, task_id: str):
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""

result = self._query_open_data(
bucket="materialsproject-parsed", prefix="bandstructures", key=task_id
)
result = self._query_open_data(bucket="materialsproject-parsed", prefix="bandstructures", key=task_id)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -275,46 +255,32 @@ def get_bandstructure_from_material_id(
Returns:
bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object
"""
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)
es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)

if line_mode:
bs_data = es_rester.get_data_by_id(
document_id=material_id, fields=["bandstructure"]
).bandstructure
bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["bandstructure"]).bandstructure

if bs_data is None:
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
raise MPRestError(f"No {path_type.value} band structure data found for {material_id}")
else:
bs_data = bs_data.dict()

if bs_data.get(path_type.value, None):
bs_task_id = bs_data[path_type.value]["task_id"]
else:
raise MPRestError(
f"No {path_type.value} band structure data found for {material_id}"
)
raise MPRestError(f"No {path_type.value} band structure data found for {material_id}")
else:
bs_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dos
bs_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dos

if bs_data is None:
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)
raise MPRestError(f"No uniform band structure data found for {material_id}")
else:
bs_data = bs_data.dict()

if bs_data.get("total", None):
bs_task_id = bs_data["total"]["1"]["task_id"]
else:
raise MPRestError(
f"No uniform band structure data found for {material_id}"
)
raise MPRestError(f"No uniform band structure data found for {material_id}")

bs_obj = self.get_bandstructure_from_task_id(bs_task_id)

Expand Down Expand Up @@ -386,9 +352,7 @@ def search(
query_params["orbital"] = orbital.value

if band_gap:
query_params.update(
{"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]}
)
query_params.update({"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]})

if efermi:
query_params.update({"efermi_min": efermi[0], "efermi_max": efermi[1]})
Expand All @@ -397,15 +361,9 @@ def search(
query_params.update({"magnetic_ordering": magnetic_ordering.value})

if sort_fields:
query_params.update(
{"_sort_fields": ",".join([s.strip() for s in sort_fields])}
)
query_params.update({"_sort_fields": ",".join([s.strip() for s in sort_fields])})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}
query_params = {entry: query_params[entry] for entry in query_params if query_params[entry] is not None}

return super()._search(
num_chunks=num_chunks,
Expand All @@ -424,9 +382,7 @@ def get_dos_from_task_id(self, task_id: str):
Returns:
bandstructure (CompleteDos): CompleteDos object
"""
result = self._query_open_data(
bucket="materialsproject-parsed", prefix="dos", key=task_id
)
result = self._query_open_data(bucket="materialsproject-parsed", prefix="dos", key=task_id)

if result.get("data", None) is not None:
return result["data"]
Expand All @@ -442,25 +398,18 @@ def get_dos_from_material_id(self, material_id: str):
Returns:
dos (CompleteDos): CompleteDos object
"""
es_rester = ElectronicStructureRester(
endpoint=self.base_endpoint, api_key=self.api_key
)
es_rester = ElectronicStructureRester(endpoint=self.base_endpoint, api_key=self.api_key)

dos_data = es_rester.get_data_by_id(
document_id=material_id, fields=["dos"]
).dict()
dos_data = es_rester.get_data_by_id(document_id=material_id, fields=["dos"]).dict()

if dos_data["dos"]:
dos_task_id = dos_data["dos"]["total"]["1"]["task_id"]
else:
raise MPRestError(f"No density of states data found for {material_id}")

dos_obj = self.get_dos_from_task_id(dos_task_id)

if dos_obj:
b64_bytes = base64.b64decode(dos_obj[0], validate=True)
packed_bytes = zlib.decompress(b64_bytes)
json_data = msgpack.unpackb(packed_bytes, raw=False)
data = MontyDecoder().process_decoded(json_data["data"])
return data
return dos_obj
else:
raise MPRestError("No density of states object found.")

0 comments on commit 8f0358a

Please sign in to comment.