Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate AWS open data buckets #815

Merged
merged 9 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Materials Project data.
"""

import gzip
import itertools
import json
import platform
Expand All @@ -13,9 +14,10 @@
from json import JSONDecodeError
from math import ceil
from os import environ
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import quote, urljoin

import boto3
import requests
from emmet.core.utils import jsanitize
from monty.json import MontyDecoder
Expand Down Expand Up @@ -55,6 +57,7 @@ def __init__(
endpoint: str = DEFAULT_ENDPOINT,
include_user_agent: bool = True,
session: Optional[requests.Session] = None,
s3_resource: Optional[Any] = None,
debug: bool = False,
monty_decode: bool = True,
use_document_model: bool = True,
Expand Down Expand Up @@ -108,6 +111,11 @@ def __init__(
else:
self._session = None # type: ignore

if s3_resource:
self._s3_resource = s3_resource
else:
self._s3_resource = None

self.document_model = (
api_sanitize(self.document_model) if self.document_model is not None else None # type: ignore
)
Expand All @@ -120,6 +128,12 @@ def session(self) -> requests.Session:
)
return self._session

@property
def s3_resource(self):
if not self._s3_resource:
self._s3_resource = boto3.resource("s3")
return self._s3_resource

@staticmethod
def _create_session(api_key, include_user_agent, headers):
session = requests.Session()
Expand Down Expand Up @@ -230,6 +244,26 @@ def _post_resource(
except RequestException as ex:
raise MPRestError(str(ex))

def _query_open_data(self, bucket: str, prefix: str, key: str) -> dict:
"""Query Materials Project AWS open data s3 buckets

Args:
bucket (str): Materials project bucket name
prefix (str): Full set of file prefixes
key (str): Key for file

Returns:
dict: MontyDecoded data
"""
ref = self.s3_resource.Object(bucket, f"{prefix}/{key}.json.gz") # type: ignore
bytes = ref.get()["Body"] # type: ignore

with gzip.GzipFile(fileobj=bytes) as gzipfile:
content = gzipfile.read()
result = MontyDecoder().decode(content)

return result

def _query_resource(
self,
criteria: Optional[Dict] = None,
Expand Down
8 changes: 6 additions & 2 deletions mp_api/client/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def api_sanitize(
fields_to_leave: list of strings for model fields as "model__name__.field"
"""
models = [
model for model in get_flat_models_from_model(pydantic_model) if issubclass(model, BaseModel)
model
for model in get_flat_models_from_model(pydantic_model)
if issubclass(model, BaseModel)
] # type: List[Type[BaseModel]]

fields_to_leave = fields_to_leave or []
Expand Down Expand Up @@ -92,7 +94,9 @@ def validate_monty(cls, v):
errors.append("@class")

if len(errors) > 0:
raise ValueError("Missing Monty seriailzation fields in dictionary: {errors}")
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
Expand Down
15 changes: 12 additions & 3 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from os import environ
from typing import Dict, List, Literal, Optional, Union

from emmet.core.charge_density import ChgcarDataDoc
from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
from emmet.core.settings import EmmetSettings
from emmet.core.tasks import TaskDoc
from emmet.core.vasp.calc_types import CalcType
from packaging import version
from pymatgen.analysis.phase_diagram import PhaseDiagram
Expand Down Expand Up @@ -1235,14 +1235,23 @@ def get_charge_density_from_material_id(
task_ids = self.get_task_ids_associated_with_material_id(
material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static]
)
results: List[ChgcarDataDoc] = self.charge_density.search(task_ids=task_ids) # type: ignore
results: List[TaskDoc] = self.tasks.search(task_ids=task_ids, fields=["last_updated", "task_id"]) # type: ignore

if len(results) == 0:
return None

latest_doc = max(results, key=lambda x: x.last_updated)

chgcar = self.charge_density.get_charge_density_from_file_id(latest_doc.fs_id)
result = (
self.tasks._query_open_data(
bucket="materialsproject-parsed",
prefix="chgcars",
key=str(latest_doc.task_id),
)
or {}
)

chgcar = result.get("data", None)

if chgcar is None:
raise MPRestError(f"No charge density fetched for {material_id}.")
Expand Down
7 changes: 4 additions & 3 deletions mp_api/client/routes/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ def set_message(
title: str,
body: str,
type: MessageType = MessageType.generic,
authors: List[str] = [],
authors: List[str] = None,
): # pragma: no cover
"""Set user settings
"""
Set user settings

Args:
title: Message title
Expand All @@ -34,7 +35,7 @@ def set_message(
Raises:
MPRestError.
"""
d = {"title": title, "body": body, "type": type.value, "authors": authors}
d = {"title": title, "body": body, "type": type.value, "authors": authors or []}

return self._post_resource(body=d).get("data")

Expand Down
73 changes: 1 addition & 72 deletions mp_api/client/routes/materials/charge_density.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import zlib
from os import environ
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import boto3
import msgpack
from botocore import UNSIGNED
from botocore.client import Config
from botocore.exceptions import ConnectionError
from emmet.core.charge_density import ChgcarDataDoc
from monty.serialization import MontyDecoder, dumpfn
from monty.serialization import dumpfn

from mp_api.client.core import BaseRester
from mp_api.client.core.utils import validate_ids
Expand Down Expand Up @@ -76,67 +69,3 @@ def search( # type: ignore
fields=["last_updated", "task_id", "fs_id"],
**query_params,
)

def get_charge_density_from_file_id(self, fs_id: str):
url_doc = self.get_data_by_id(fs_id)

if url_doc:
# The check below is performed to see if the client is being
# used by our internal AWS deployment. If it is, we pull charge
# density data from a private S3 bucket. Else, we pull data
# from public MinIO buckets.
if environ.get("AWS_EXECUTION_ENV", None) == "AWS_ECS_FARGATE":
if self.boto_resource is None:
self.boto_resource = self._get_s3_resource(
use_minio=False, unsigned=False
)

bucket, obj_prefix = self._extract_s3_url_info(url_doc, use_minio=False)

else:
try:
if self.boto_resource is None:
self.boto_resource = self._get_s3_resource()

bucket, obj_prefix = self._extract_s3_url_info(url_doc)

except ConnectionError:
self.boto_resource = self._get_s3_resource(use_minio=False)

bucket, obj_prefix = self._extract_s3_url_info(
url_doc, use_minio=False
)

r = self.boto_resource.Object(bucket, f"{obj_prefix}/{url_doc.fs_id}").get()["Body"] # type: ignore

packed_bytes = r.read()

packed_bytes = zlib.decompress(packed_bytes)
json_data = msgpack.unpackb(packed_bytes, raw=False)
chgcar = MontyDecoder().process_decoded(json_data["data"])

return chgcar

else:
return None

def _extract_s3_url_info(self, url_doc, use_minio: bool = True):
if use_minio:
url_list = url_doc.url.split("/")
bucket = url_list[3]
obj_prefix = url_list[4]
else:
url_list = url_doc.s3_url_prefix.split("/")
bucket = url_list[2].split(".")[0]
obj_prefix = url_list[3]

return (bucket, obj_prefix)

def _get_s3_resource(self, use_minio: bool = True, unsigned: bool = True):
resource = boto3.resource(
"s3",
endpoint_url="https://minio.materialsproject.org" if use_minio else None,
config=Config(signature_version=UNSIGNED) if unsigned else None,
)

return resource
24 changes: 6 additions & 18 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,9 @@ def get_bandstructure_from_task_id(self, task_id: str):
Returns:
bandstructure (BandStructure): BandStructure or BandStructureSymmLine object
"""
result = self._query_resource(
criteria={"task_id": task_id, "_all_fields": True},
suburl="object",
use_document_model=False,
num_chunks=1,
chunk_size=1,

result = self._query_open_data(
bucket="materialsproject-parsed", prefix="bandstructures", key=task_id
)

if result.get("data", None) is not None:
Expand Down Expand Up @@ -322,12 +319,7 @@ def get_bandstructure_from_material_id(
bs_obj = self.get_bandstructure_from_task_id(bs_task_id)

if bs_obj:
b64_bytes = base64.b64decode(bs_obj[0], validate=True)
packed_bytes = zlib.decompress(b64_bytes)
json_data = msgpack.unpackb(packed_bytes, raw=False)
data = MontyDecoder().process_decoded(json_data["data"])

return data
return bs_obj
else:
raise MPRestError("No band structure object found.")

Expand Down Expand Up @@ -432,12 +424,8 @@ def get_dos_from_task_id(self, task_id: str):
Returns:
bandstructure (CompleteDos): CompleteDos object
"""
result = self._query_resource(
criteria={"task_id": task_id, "_all_fields": True},
suburl="object",
use_document_model=False,
num_chunks=1,
chunk_size=1,
result = self._query_open_data(
bucket="materialsproject-parsed", prefix="dos", key=task_id
)

if result.get("data", None) is not None:
Expand Down
13 changes: 11 additions & 2 deletions tests/molecules/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

from mp_api.client.routes.molecules.summary import MoleculesSummaryRester

excluded_params = ["sort_fields", "chunk_size", "num_chunks", "all_fields", "fields", "exclude_elements"]
excluded_params = [
"sort_fields",
"chunk_size",
"num_chunks",
"all_fields",
"fields",
"exclude_elements",
]

alt_name = {"formula": "formula_alphabetical", "molecule_ids": "molecule_id"}

Expand All @@ -25,7 +32,9 @@
} # type: dict


@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
def test_client():
search_method = MoleculesSummaryRester().search

Expand Down
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
ignore_generic = [
"_user_settings",
"_general_store",
"_messages",
# "tasks",
# "bonds",
"materials_xas",
"materials_elasticity",
"materials_fermi",
"molecules_vibrations",
# "alloys",
# "summary",
] # temp
Expand Down
Loading