Skip to content

Commit

Permalink
Integrate AWS open data buckets (#815)
Browse files Browse the repository at this point in the history
* Pull BS, DOS, and CHGCAR from open data

* Linting

* Fix default arg in messages rester

* Small linting

* Add messages rester to tests

* Linting

* Add molecule vibrations to ignore rester list

* Linting
  • Loading branch information
munrojm committed Jun 14, 2023
1 parent f7e2a4e commit bf390c7
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 101 deletions.
36 changes: 35 additions & 1 deletion mp_api/client/core/client.py
Expand Up @@ -3,6 +3,7 @@
Materials Project data.
"""

import gzip
import itertools
import json
import platform
Expand All @@ -13,9 +14,10 @@
from json import JSONDecodeError
from math import ceil
from os import environ
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import quote, urljoin

import boto3
import requests
from emmet.core.utils import jsanitize
from monty.json import MontyDecoder
Expand Down Expand Up @@ -55,6 +57,7 @@ def __init__(
endpoint: str = DEFAULT_ENDPOINT,
include_user_agent: bool = True,
session: Optional[requests.Session] = None,
s3_resource: Optional[Any] = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
Expand Down Expand Up @@ -108,6 +111,11 @@ def __init__(
else:
self._session = None # type: ignore

if s3_resource:
self._s3_resource = s3_resource
else:
self._s3_resource = None

self.document_model = (
api_sanitize(self.document_model) if self.document_model is not None else None # type: ignore
)
Expand All @@ -120,6 +128,12 @@ def session(self) -> requests.Session:
)
return self._session

@property
def s3_resource(self):
if not self._s3_resource:
self._s3_resource = boto3.resource("s3")
return self._s3_resource

@staticmethod
def _create_session(api_key, include_user_agent, headers):
session = requests.Session()
Expand Down Expand Up @@ -230,6 +244,26 @@ def _post_resource(
except RequestException as ex:
raise MPRestError(str(ex))

def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
"""Query Materials Project AWS open data s3 buckets
Args:
bucket (str): Materials project bucket name
prefix (str): Full set of file prefixes
key (str): Key for file
Returns:
dict: MontyDecoded data
"""
ref = self.s3_resource.Object(bucket, f"{prefix}/{key}.json.gz") # type: ignore
bytes = ref.get()["Body"] # type: ignore

with gzip.GzipFile(fileobj=bytes) as gzipfile:
content = gzipfile.read()
result = MontyDecoder().decode(content)

return result

def _query_resource(
self,
criteria: Optional[Dict] = None,
Expand Down
8 changes: 6 additions & 2 deletions mp_api/client/core/utils.py
Expand Up @@ -43,7 +43,9 @@ def api_sanitize(
fields_to_leave: list of strings for model fields as "model__name__.field"
"""
models = [
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]

fields_to_leave = fields_to_leave or []
Expand Down Expand Up @@ -92,7 +94,9 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
Expand Down
15 changes: 12 additions & 3 deletions mp_api/client/mprester.py
Expand Up @@ -5,10 +5,10 @@
from os import environ
from typing import Dict, List, Literal, Optional, Union

from emmet.core.charge_density import ChgcarDataDoc
from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
from emmet.core.settings import EmmetSettings
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types import CalcType
from packaging import version
from pymatgen.analysis.phase_diagram import PhaseDiagram
Expand Down Expand Up @@ -1235,14 +1235,23 @@ def get_charge_density_from_material_id(
task_ids = self.get_task_ids_associated_with_material_id(
material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static]
)
results: List[ChgcarDataDoc] = self.charge_density.search(task_ids=task_ids) # type: ignore
results: List[TaskDoc] = self.tasks.search(task_ids=task_ids, fields=["last_updated", "task_id"]) # type: ignore

if len(results) == 0:
return None

latest_doc = max(results, key=lambda x: x.last_updated)

chgcar = self.charge_density.get_charge_density_from_file_id(latest_doc.fs_id)
result = (
self.tasks._query_open_data(
bucket="materialsproject-parsed",
prefix="chgcars",
key=str(latest_doc.task_id),
)
or {}
)

chgcar = result.get("data", None)

if chgcar is None:
raise MPRestError(f"No charge density fetched for {material_id}.")
Expand Down
7 changes: 4 additions & 3 deletions mp_api/client/routes/_messages.py
Expand Up @@ -18,9 +18,10 @@ def set_message(
title: str,
body: str,
type: MessageType = MessageType.generic,
authors: List[str] = [],
authors: List[str] = None,
): # pragma: no cover
"""Set user settings
"""
Set user settings
Args:
title: Message title
Expand All @@ -34,7 +35,7 @@ def set_message(
Raises:
MPRestError.
"""
d = {"title": title, "body": body, "type": type.value, "authors": authors}
d = {"title": title, "body": body, "type": type.value, "authors": authors or []}

return self._post_resource(body=d).get("data")

Expand Down
73 changes: 1 addition & 72 deletions mp_api/client/routes/materials/charge_density.py
@@ -1,15 +1,8 @@
import zlib
from os import environ
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import boto3
import msgpack
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ConnectionError
from emmet.core.charge_density import ChgcarDataDoc
from monty.serialization import MontyDecoder, dumpfn
from monty.serialization import dumpfn

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids
Expand Down Expand Up @@ -76,67 +69,3 @@ def search( # type: ignore
fields=["last_updated", "task_id", "fs_id"],
**query_params,
)

def get_charge_density_from_file_id(self, fs_id: str):
url_doc = self.get_data_by_id(fs_id)

if url_doc:
# The check below is performed to see if the client is being
# used by our internal AWS deployment. If it is, we pull charge
# density data from a private S3 bucket. Else, we pull data
# from public MinIO buckets.
if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE":
if self.boto_resource is None:
self.boto_resource = self._get_s3_resource(
use_minio=False, unsigned=False
)

bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False)

else:
try:
if self.boto_resource is None:
self.boto_resource = self._get_s3_resource()

bucket, obj_prefix = self._extract_s3_url_info(url_doc)

except ConnectionError:
self.boto_resource = self._get_s3_resource(use_minio=False)

bucket, obj_prefix = self._extract_s3_url_info(
url_doc, use_minio=False
)

r = self.boto_resource.Object(bucket, f"{obj_prefix}/{url_doc.fs_id}").get()["Body"] # type: ignore

packed_bytes = r.read()

packed_bytes = zlib.decompress(packed_bytes)
json_data = msgpack.unpackb(packed_bytes, raw=False)
chgcar = MontyDecoder().process_decoded(json_data["data"])

return chgcar

else:
return None

def _extract_s3_url_info(self, url_doc, use_minio: bool = True):
if use_minio:
url_list = url_doc.url.split("/")
bucket = url_list[3]
obj_prefix = url_list[4]
else:
url_list = url_doc.s3_url_prefix.split("/")
bucket = url_list[2].split(".")[0]
obj_prefix = url_list[3]

return (bucket, obj_prefix)

def _get_s3_resource(self, use_minio: bool = True, unsigned: bool = True):
resource = boto3.resource(
"s3",
endpoint_url="https://minio.materialsproject.org" if use_minio else None,
config=Config(signature_version=UNSIGNED) if unsigned else None,
)

return resource
24 changes: 6 additions & 18 deletions mp_api/client/routes/materials/electronic_structure.py
Expand Up @@ -249,12 +249,9 @@ def get_bandstructure_from_task_id(self, task_id: str):
Returns:
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""
result = self._query_resource(
criteria={"task_id": task_id, "_all_fields": True},
suburl="object",
use_document_model=False,
num_chunks=1,
chunk_size=1,

result = self._query_open_data(
bucket="materialsproject-parsed", prefix="bandstructures", key=task_id
)

if result.get("data", None) is not None:
Expand Down Expand Up @@ -322,12 +319,7 @@ def get_bandstructure_from_material_id(
bs_obj = self.get_bandstructure_from_task_id(bs_task_id)

if bs_obj:
b64_bytes = base64.b64decode(bs_obj[0], validate=True)
packed_bytes = zlib.decompress(b64_bytes)
json_data = msgpack.unpackb(packed_bytes, raw=False)
data = MontyDecoder().process_decoded(json_data["data"])

return data
return bs_obj
else:
raise MPRestError("No band structure object found.")

Expand Down Expand Up @@ -432,12 +424,8 @@ def get_dos_from_task_id(self, task_id: str):
Returns:
bandstructure (CompleteDos): CompleteDos object
"""
result = self._query_resource(
criteria={"task_id": task_id, "_all_fields": True},
suburl="object",
use_document_model=False,
num_chunks=1,
chunk_size=1,
result = self._query_open_data(
bucket="materialsproject-parsed", prefix="dos", key=task_id
)

if result.get("data", None) is not None:
Expand Down
13 changes: 11 additions & 2 deletions tests/molecules/test_summary.py
Expand Up @@ -6,7 +6,14 @@

from mp_api.client.routes.molecules.summary import MoleculesSummaryRester

excluded_params = ["sort_fields", "chunk_size", "num_chunks", "all_fields", "fields", "exclude_elements"]
excluded_params = [
"sort_fields",
"chunk_size",
"num_chunks",
"all_fields",
"fields",
"exclude_elements",
]

alt_name = {"formula": "formula_alphabetical", "molecule_ids": "molecule_id"}

Expand All @@ -25,7 +32,9 @@
} # type: dict


@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
def test_client():
search_method = MoleculesSummaryRester().search

Expand Down
2 changes: 2 additions & 0 deletions tests/test_client.py
Expand Up @@ -31,11 +31,13 @@
ignore_generic = [
"_user_settings",
"_general_store",
"_messages",
# "tasks",
# "bonds",
"materials_xas",
"materials_elasticity",
"materials_fermi",
"molecules_vibrations",
# "alloys",
# "summary",
] # temp
Expand Down

0 comments on commit bf390c7

Please sign in to comment.