Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/materialsproject/api.git; b…
Browse files Browse the repository at this point in the history
…ranch 'main' of https://github.com/materialsproject/api into main
  • Loading branch information
munrojm committed Apr 30, 2021
2 parents 9c3fc34 + dbcc2b8 commit b3a6ac7
Show file tree
Hide file tree
Showing 21 changed files with 164 additions and 448 deletions.
7 changes: 5 additions & 2 deletions app.py
Expand Up @@ -71,7 +71,7 @@
materials_store = MongoURIStore(
uri=f"mongodb+srv://{db_uri}",
database="mp_core",
key="task_id",
key="material_id",
collection_name=f"materials.core_{db_version}",
)

Expand Down Expand Up @@ -480,7 +480,10 @@
resources.update({"mpcomplete": mpcomplete_resource(mpcomplete_store)})

# Consumers
from mp_api.routes._consumer.resources import set_settings_resource, get_settings_resource
from mp_api.routes._consumer.resources import (
set_settings_resource,
get_settings_resource,
)

resources.update({"user_settings": get_settings_resource(consumer_settings_store)})
resources.update({"user_settings/set": set_settings_resource(consumer_settings_store)})
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Expand Up @@ -12,7 +12,7 @@ services:
- PORT=5001
- NUM_WORKERS=4
- MPCONTRIBS_MONGO_HOST=$MPCONTRIBS_MONGO_HOST
- DB_VERSION=2020_09_08
- DB_VERSION=2021_04_26
volumes:
- .:/app
ports:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -10,7 +10,7 @@ statistics = True
[flake8]
exclude = .git,__pycache__,docs_rst/conf.py,tests,pymatgen/io/abinit,__init__.py
# max-complexity = 10
extend-ignore = E741
extend-ignore = E741, F401, F403
max-line-length = 120

[pydocstyle]
Expand Down
53 changes: 37 additions & 16 deletions src/mp_api/core/client.py
Expand Up @@ -47,7 +47,7 @@ def __init__(
version=None,
include_user_agent=True,
session=None,
debug=False
debug=False,
):
"""
Args:
Expand Down Expand Up @@ -247,7 +247,7 @@ def _query_resource(
monty_decode: bool = True,
suburl: Optional[str] = None,
use_document_model: Optional[bool] = True,
version: Optional[str] = None
version: Optional[str] = None,
):
"""
Query the endpoint for a Resource containing a list of documents
Expand Down Expand Up @@ -337,7 +337,7 @@ def query(
fields: Optional[List[str]] = None,
monty_decode: bool = True,
suburl: Optional[str] = None,
version: Optional[str] = None
version: Optional[str] = None,
):
"""
Query the endpoint for a list of documents.
Expand All @@ -354,7 +354,11 @@ def query(
A list of documents
"""
return self._query_resource(
criteria=criteria, fields=fields, monty_decode=monty_decode, suburl=suburl, version=version
criteria=criteria,
fields=fields,
monty_decode=monty_decode,
suburl=suburl,
version=version,
).get("data")

def get_document_by_id(
Expand Down Expand Up @@ -412,12 +416,14 @@ def get_document_by_id(
else:
return results[0]

def search(self,
version: Optional[str] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 1000,
fields: Optional[List[str]] = None,
**kwargs):
def search(
self,
version: Optional[str] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 1000,
fields: Optional[List[str]] = None,
**kwargs,
):
"""
A generic search method to retrieve documents matching specific parameters.
Expand All @@ -436,11 +442,21 @@ def search(self,
# documented kwargs.

if not fields:
warnings.warn(f"No data fields requested. Choose from: {self.available_fields}")
warnings.warn(
f"No data fields requested. Choose from: {self.available_fields}"
)

return self._get_all_documents(kwargs, fields=fields, version=version, chunk_size=chunk_size, num_chunks=num_chunks)
return self._get_all_documents(
kwargs,
fields=fields,
version=version,
chunk_size=chunk_size,
num_chunks=num_chunks,
)

def _get_all_documents(self, query_params, fields=None, version=None, chunk_size=1000, num_chunks=None):
def _get_all_documents(
self, query_params, fields=None, version=None, chunk_size=1000, num_chunks=None
):
"""
Iterates over pages until all documents are retrieved. Displays
progress using tqdm. This method is designed to give a common
Expand All @@ -459,7 +475,10 @@ def _get_all_documents(self, query_params, fields=None, version=None, chunk_size
count = 1

# progress bar
t = tqdm(desc=f"Retrieving {self.document_model.__name__} documents", total=results["meta"]["total"])
t = tqdm(
desc=f"Retrieving {self.document_model.__name__} documents",
total=results["meta"]["total"],
)
t.update(len(all_results))

while True:
Expand All @@ -468,7 +487,9 @@ def _get_all_documents(self, query_params, fields=None, version=None, chunk_size

t.update(len(results["data"]))

if not any(results["data"]) or (num_chunks is not None and count == num_chunks):
if not any(results["data"]) or (
num_chunks is not None and count == num_chunks
):
break

count += 1
Expand Down Expand Up @@ -504,7 +525,7 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
def available_fields(self) -> List[str]:
if self.document_model is None:
return ["Unknown fields."]
return list(self.document_model.schema()['properties'].keys()) # type: ignore
return list(self.document_model.schema()["properties"].keys()) # type: ignore

def __repr__(self):
return f"<{self.__class__.__name__} {self.endpoint}>"
Expand Down
85 changes: 84 additions & 1 deletion src/mp_api/core/models.py
@@ -1,7 +1,10 @@
from datetime import datetime
from typing import Generic, List, Optional, TypeVar, Type
from monty.json import MSONable
from pydantic.utils import lenient_issubclass
from pydantic.schema import get_flat_models_from_model
from mp_api import __version__
from pydantic import BaseModel, Field, validator
from typing import Generic, TypeVar, Optional, List
from pydantic.generics import GenericModel


Expand Down Expand Up @@ -76,3 +79,83 @@ def default_meta(cls, v, values):
if "total" not in v and values.get("data", None) is not None:
v["total"] = len(values["data"])
return v


def api_sanitize(
pydantic_model: Type[BaseModel],
fields_to_leave: Optional[List[str]] = None,
allow_dict_msonable=False,
):
"""
Function to clean up pydantic models for the API by:
1.) Making fields optional
2.) Allowing dictionaries in-place of the objects for MSONable quantities
WARNING: This works in place, so it mutates the model and all sub-models
Args:
fields_to_leave: list of strings for model fields as "model__name__.field"
"""

models = [
model
for model in get_flat_models_from_model(pydantic_model)
if lenient_issubclass(model, BaseModel)
]

fields_to_leave = fields_to_leave or []
fields_tuples = [f.split(".") for f in fields_to_leave]
assert all(len(f) == 2 for f in fields_tuples)

for model in models:
model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]}
for name, field in model.__fields__.items(): # type: ignore
field_type = field.type_

if name not in model_fields_to_leave:
field.required = False
field.field_info.default = None

if (
field_type is not None
and lenient_issubclass(field_type, MSONable)
and allow_dict_msonable
):
field.type_ = allow_msonable_dict(field_type)
field.populate_validators()

return pydantic_model


def allow_msonable_dict(monty_cls: Type[MSONable]):
"""
Patch Monty to allow for dict values for MSONable
"""

def validate_monty(cls, v):
"""
Stub validator for MSONable as a dictionary only
"""
if isinstance(v, cls):
return v
elif isinstance(v, dict):
# Just validate the simple Monty Dict Model
errors = []
if v.get("@module", "") != monty_cls.__module__:
errors.append("@module")

if v.get("@class", "") != monty_cls.__name__:
errors.append("@class")

if len(errors) > 0:
raise ValueError(
"Missing Monty seriailzation fields in dictionary: {errors}"
)

return v
else:
raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary")

setattr(monty_cls, "validate_monty", classmethod(validate_monty))

return monty_cls
6 changes: 3 additions & 3 deletions src/mp_api/core/resource.py
Expand Up @@ -9,7 +9,7 @@

from maggma.core import Store

from mp_api.core.models import Response
from mp_api.core.models import Response, api_sanitize
from mp_api.core.utils import (
STORE_PARAMS,
merge_queries,
Expand Down Expand Up @@ -61,9 +61,9 @@ def __init__(
assert issubclass(
class_model, BaseModel
), "The resource model has to be a PyDantic Model"
self.model = class_model
self.model = api_sanitize(class_model, allow_dict_msonable=True)
elif isinstance(model, type) and issubclass(model, (BaseModel, MSONable)):
self.model = model
self.model = api_sanitize(model, allow_dict_msonable=True)
else:
raise ValueError("The resource model has to be a PyDantic Model")

Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/matproj.py
Expand Up @@ -118,7 +118,7 @@ def get_structure_by_material_id(
Structure object.
"""
# TODO: decide about `final` and `conventional_unit_cell`
return self.materials.get_structure_by_material_id(material_id=material_id)
return self.materials.get_structure_by_material_id(material_id=material_id) # type: ignore

# @deprecated(self.materials.get_database_version, _DEPRECATION_WARNING)
def get_database_version(self):
Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/routes/__init__.py
Expand Up @@ -22,4 +22,4 @@
from mp_api.routes.molecules.client import MoleculesRester
from mp_api.routes.synthesis.client import SynthesisRester
from mp_api.routes.electrodes.client import ElectrodeRester
from mp_api.routes.charge_density.client import ChargeDensityRester
from mp_api.routes.charge_density.client import ChargeDensityRester
2 changes: 1 addition & 1 deletion src/mp_api/routes/_consumer/client.py
@@ -1,4 +1,4 @@
from mp_api.routes._consumer import UserSettingsDoc
from mp_api.routes._consumer.models import UserSettingsDoc
from mp_api.core.client import BaseRester


Expand Down
49 changes: 4 additions & 45 deletions src/mp_api/routes/dielectric/models.py
@@ -1,59 +1,18 @@
from typing import List
from datetime import datetime
from monty.json import MontyDecoder

from pydantic import BaseModel, Field, validator


class DielectricData(BaseModel):
"""
Model for dielectric data in dielectric document
"""

total: List[List[float]] = Field(
None,
description="Total dielectric tensor",
)

ionic: List[List[float]] = Field(
None,
description="Ionic contribution to dielectric tensor",
)

static: List[List[float]] = Field(
None,
description="Electronic contribution to dielectric tensor",
)
from emmet.core.polar import Dielectric

e_total: float = Field(
None,
description="Total dielectric constant",
)

e_ionic: float = Field(
None,
description="Ionic contributio to dielectric constant",
)

e_static: float = Field(
None,
description="Electronic contribution to dielectric constant",
)

n: float = Field(
None,
description="Refractive index",
)
from pydantic import BaseModel, Field, validator


class DielectricDoc(BaseModel):
"""
Model for a document containing dielectric data
"""

dielectric: DielectricData = Field(
None,
description="Dielectric data",
dielectric: Dielectric = Field(
None, description="Dielectric data",
)

task_id: str = Field(
Expand Down
2 changes: 1 addition & 1 deletion src/mp_api/routes/electrodes/client.py
@@ -1,5 +1,5 @@
from mp_api.core.client import BaseRester
from mp_api.routes.electrodes.models import InsertionElectrodeDoc
from emmet.core.electrode import InsertionElectrodeDoc


class ElectrodeRester(BaseRester):
Expand Down

0 comments on commit b3a6ac7

Please sign in to comment.