diff --git a/app.py b/app.py index 6b7d56ef..67664513 100644 --- a/app.py +++ b/app.py @@ -1,5 +1,6 @@ import os from monty.serialization import loadfn + from mp_api.core.api import MAPI from mp_api.core.settings import MAPISettings @@ -242,7 +243,7 @@ s3_bs_index = MongoURIStore( uri=f"mongodb+srv://{db_uri}", database="mp_core", - key="fs_id", + key="task_id", collection_name="s3_bandstructure_index", ) @@ -257,7 +258,7 @@ index=s3_bs_index, bucket="mp-bandstructures", compress=True, - key="fs_id", + key="task_id", searchable_fields=["task_id", "fs_id"], ) @@ -272,7 +273,7 @@ s3_chgcar_index = MongoURIStore( uri=f"mongodb+srv://{db_uri}", database="mp_core", - key="fs_id", + key="task_id", collection_name="atomate_chgcar_fs_index", ) @@ -281,14 +282,14 @@ bucket="mp-volumetric", sub_dir="atomate_chgcar_fs/", compress=True, - key="fs_id", + key="task_id", searchable_fields=["task_id", "fs_id"], ) mpcomplete_store = MongoURIStore( uri=f"mongodb+srv://{db_uri}", database="mp_consumers", - key="snl_id", + key="submission_id", collection_name="mpcomplete", ) @@ -339,159 +340,195 @@ consumer_settings_store = loadfn(consumer_settings_store_json) # Materials -from mp_api.routes.materials.resources import materials_resource +from mp_api.routes.materials.resources import ( + materials_resource, + find_structure_resource, + formula_autocomplete_resource, +) resources.update( - {"materials": materials_resource(materials_store, formula_autocomplete_store)} + { + "materials": [ + find_structure_resource(materials_store), + formula_autocomplete_resource(formula_autocomplete_store), + materials_resource(materials_store), + ] + } ) -# Tasks -from mp_api.routes.tasks.resources import task_resource - -resources.update({"tasks": task_resource(task_store)}) +# resources.update({"find_structure": find_structure_resource(materials_store)}) -# Task Deprecation -from mp_api.routes.tasks.resources import task_deprecation_resource - -resources.update({"deprecation": task_deprecation_resource(materials_store)}) - -# Trajectory -from mp_api.routes.tasks.resources import trajectory_resource +# Tasks +from mp_api.routes.tasks.resources import ( + task_resource, + trajectory_resource, + task_deprecation_resource, +) -resources.update({"trajectory": trajectory_resource(task_store)}) +resources.update( + { + "tasks": [ + trajectory_resource(task_store), + task_deprecation_resource(materials_store), + task_resource(task_store), + ] + } +) # Thermo from mp_api.routes.thermo.resources import thermo_resource -resources.update({"thermo": thermo_resource(thermo_store)}) +resources.update({"thermo": [thermo_resource(thermo_store)]}) # Dielectric from mp_api.routes.dielectric.resources import dielectric_resource -resources.update({"dielectric": dielectric_resource(dielectric_piezo_store)}) +resources.update({"dielectric": [dielectric_resource(dielectric_piezo_store)]}) # Magnetism from mp_api.routes.magnetism.resources import magnetism_resource -resources.update({"magnetism": magnetism_resource(magnetism_store)}) +resources.update({"magnetism": [magnetism_resource(magnetism_store)]}) # Piezoelectric from mp_api.routes.piezo.resources import piezo_resource -resources.update({"piezoelectric": piezo_resource(dielectric_piezo_store)}) +resources.update({"piezoelectric": [piezo_resource(dielectric_piezo_store)]}) # Phonon from mp_api.routes.phonon.resources import phonon_bs_resource, phonon_img_resource -resources.update({"phonon": phonon_bs_resource(phonon_bs_store)}) -resources.update({"phonon_img": phonon_img_resource(phonon_img_store)}) +resources.update( + { + "phonon": [ + phonon_img_resource(phonon_img_store), + phonon_bs_resource(phonon_bs_store), + ] + } +) # EOS from mp_api.routes.eos.resources import eos_resource -resources.update({"eos": eos_resource(eos_store)}) +resources.update({"eos": [eos_resource(eos_store)]}) # Similarity from mp_api.routes.similarity.resources import similarity_resource -resources.update({"similarity": similarity_resource(similarity_store)}) +resources.update({"similarity": [similarity_resource(similarity_store)]}) # XAS from mp_api.routes.xas.resources import xas_resource -resources.update({"xas": xas_resource(xas_store)}) +resources.update({"xas": [xas_resource(xas_store)]}) # Grain Boundaries from mp_api.routes.grain_boundary.resources import gb_resource -resources.update({"grain_boundary": gb_resource(gb_store)}) +resources.update({"grain_boundary": [gb_resource(gb_store)]}) # Fermi Surface from mp_api.routes.fermi.resources import fermi_resource -resources.update({"fermi": fermi_resource(fermi_store)}) +resources.update({"fermi": [fermi_resource(fermi_store)]}) # Elasticity from mp_api.routes.elasticity.resources import elasticity_resource -resources.update({"elasticity": elasticity_resource(elasticity_store)}) +resources.update({"elasticity": [elasticity_resource(elasticity_store)]}) # DOIs from mp_api.routes.dois.resources import dois_resource -resources.update({"doi": dois_resource(doi_store)}) +resources.update({"doi": [dois_resource(doi_store)]}) # Substrates from mp_api.routes.substrates.resources import substrates_resource -resources.update({"substrates": substrates_resource(substrates_store)}) +resources.update({"substrates": [substrates_resource(substrates_store)]}) # Surface Properties from mp_api.routes.surface_properties.resources import surface_props_resource -resources.update({"surface_properties": surface_props_resource(surface_props_store)}) +resources.update({"surface_properties": [surface_props_resource(surface_props_store)]}) # Wulff from mp_api.routes.wulff.resources import wulff_resource -resources.update({"wulff": wulff_resource(wulff_store)}) +resources.update({"wulff": [wulff_resource(wulff_store)]}) # Robocrystallographer -from mp_api.routes.robocrys.resources import robo_resource +from mp_api.routes.robocrys.resources import robo_resource, robo_search_resource -resources.update({"robocrys": robo_resource(robo_store)}) +resources.update( + {"robocrys": [robo_search_resource(robo_store), robo_resource(robo_store)]} +) # Synthesis -from mp_api.routes.synthesis.resources import synth_resource +from mp_api.routes.synthesis.resources import synth_resource, synth_search_resource -resources.update({"synthesis": synth_resource(synth_store)}) +resources.update( + {"synthesis": [synth_resource(synth_store), synth_search_resource(synth_store)]} +) # Electrodes from mp_api.routes.electrodes.resources import insertion_electrodes_resource resources.update( - {"insertion_electrodes": insertion_electrodes_resource(insertion_electrodes_store)} + { + "insertion_electrodes": [ + insertion_electrodes_resource(insertion_electrodes_store) + ] + } ) # Molecules from mp_api.routes.molecules.resources import molecules_resource -resources.update({"molecules": molecules_resource(molecules_store)}) +resources.update({"molecules": [molecules_resource(molecules_store)]}) # Charge Density from mp_api.routes.charge_density.resources import charge_density_resource -resources.update({"charge_density": charge_density_resource(s3_chgcar)}) +resources.update({"charge_density": [charge_density_resource(s3_chgcar)]}) # Search -from mp_api.routes.search.resources import search_resource +from mp_api.routes.search.resources import search_resource, search_stats_resource -resources.update({"search": search_resource(search_store)}) +resources.update( + {"search": [search_stats_resource(search_store), search_resource(search_store)]} +) # Electronic Structure from mp_api.routes.electronic_structure.resources import ( + dos_obj_resource, es_resource, bs_resource, + bs_obj_resource, dos_resource, + dos_obj_resource, ) -resources.update({"electronic_structure": es_resource(es_store)}) -resources.update({"bandstructure": bs_resource(es_store, s3_bs)}) -resources.update({"dos": dos_resource(es_store, s3_dos)}) - +resources.update( + { + "electronic_structure": [ + bs_resource(es_store), + dos_resource(es_store), + es_resource(es_store), + bs_obj_resource(s3_bs), + dos_obj_resource(s3_dos), + ] + } +) # MPComplete from mp_api.routes.mpcomplete.resources import mpcomplete_resource -resources.update({"mpcomplete": mpcomplete_resource(mpcomplete_store)}) +resources.update({"mpcomplete": [mpcomplete_resource(mpcomplete_store)]}) # Consumers -from mp_api.routes._consumer.resources import ( - set_settings_resource, - get_settings_resource, -) +from mp_api.routes._consumer.resources import settings_resource -resources.update({"user_settings": get_settings_resource(consumer_settings_store)}) -resources.update({"user_settings/set": set_settings_resource(consumer_settings_store)}) +resources.update({"user_settings": [settings_resource(consumer_settings_store)]}) api = MAPI(resources=resources, debug=debug) diff --git a/requirements.txt b/requirements.txt index 7a64c6ab..5f5782a1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ pydantic==1.8.2 pymatgen==2022.0.8 typing-extensions==3.10.0.0 -maggma==0.28.0 +maggma==0.29.0 requests==2.25.1 monty==2021.5.9 -emmet-core @ git+https://github.com/materialsproject/emmet#subdirectory=emmet-core +emmet-core==0.2.2 diff --git a/src/mp_api/core/api.py b/src/mp_api/core/api.py index 502dc4f1..1a8a66e1 100644 --- a/src/mp_api/core/api.py +++ b/src/mp_api/core/api.py @@ -1,77 +1,43 @@ -import os -import uvicorn -from starlette.responses import RedirectResponse -from fastapi import FastAPI -from typing import Dict, Union -from datetime import datetime -from monty.json import MSONable -from mp_api.core.resource import ConsumerPostResource, GetResource -from mp_api.core.settings import MAPISettings +from typing import Dict, List +from maggma.api.resource.core import Resource from pymatgen.core import __version__ as pmg_version # type: ignore +from mp_api import __version__ as api_version from fastapi.openapi.utils import get_openapi -from fastapi.middleware.cors import CORSMiddleware +from maggma.api.API import API -class MAPI(MSONable): +class MAPI(API): """ Core Materials API that orchestrates resources together - - TODO: - Build cross-resource relationships - Global Query Operators? """ def __init__( self, - resources: Dict[str, Union[GetResource, ConsumerPostResource]], + resources: Dict[str, List[Resource]], title="Materials Project API", - version="3.0.0-dev", + version="v3.0-dev", debug=False, + heartbeat_meta={"pymatgen": pmg_version}, ): - self.resources = resources - self.title = title - self.version = version - self.debug = debug + super().__init__( + resources=resources, + title=title, + version=version, + debug=debug, + heartbeat_meta=heartbeat_meta, + ) @property def app(self): """ App server for the cluster manager """ - - # TODO this should run on `not self.debug`! - on_startup = [resource.setup_indexes for resource in self.resources.values()] if self.debug else [] - app = FastAPI(title=self.title, version=self.version, on_startup=on_startup) - if self.debug: - app.add_middleware( - CORSMiddleware, allow_origins=["*"], allow_methods=["GET"], allow_headers=["*"], - ) - - if len(self.resources) == 0: - raise RuntimeError("ERROR: There are no resources provided") - - for prefix, endpoint in self.resources.items(): - app.include_router(endpoint.router, prefix=f"/{prefix}") - - @app.get("/heartbeat", include_in_schema=False) - def heartbeat(): - """ API Heartbeat for Load Balancing """ - - return { - "status": "OK", - "time": datetime.utcnow(), - "api": self.version, - "database": os.environ.get("DB_VERSION", MAPISettings().db_version).replace("_", "."), - "pymatgen": pmg_version, - } - - @app.get("/", include_in_schema=False) - def redirect_docs(): - """ Redirects the root end point to the docs """ - return RedirectResponse(url=app.docs_url, status_code=301) + app = super().app def custom_openapi(): - openapi_schema = get_openapi(title=self.title, version=self.version, routes=app.routes) + openapi_schema = get_openapi( + title=self.title, version=self.version, routes=app.routes + ) openapi_schema["components"]["securitySchemes"] = { "ApiKeyAuth": { @@ -94,19 +60,3 @@ def custom_openapi(): app.openapi = custom_openapi return app - - def run(self, ip: str = "127.0.0.1", port: int = 8000, log_level: str = "info"): - """ - Runs the Cluster Manager locally - Meant for debugging purposes only - - Args: - ip: Local IP to listen on - port: Local port to listen on - log_level: Logging level for the webserver - - Returns: - None - """ - - uvicorn.run(self.app, host=ip, port=port, log_level=log_level, reload=False) diff --git a/src/mp_api/core/client.py b/src/mp_api/core/client.py index 3f5db969..7b249ad7 100644 --- a/src/mp_api/core/client.py +++ b/src/mp_api/core/client.py @@ -20,7 +20,7 @@ from pydantic import BaseModel from tqdm import tqdm -from mp_api.core.utils import api_sanitize +from maggma.api.utils import api_sanitize try: from pymatgen.core import __version__ as pmg_version # type: ignore diff --git a/src/mp_api/core/models.py b/src/mp_api/core/models.py deleted file mode 100644 index 0d09f8cf..00000000 --- a/src/mp_api/core/models.py +++ /dev/null @@ -1,78 +0,0 @@ -from datetime import datetime -from typing import Generic, List, Optional, TypeVar -from mp_api import __version__ -from pydantic import BaseModel, Field, validator -from pydantic.generics import GenericModel - - -""" Describes the Materials API Response """ - - -DataT = TypeVar("DataT") - - -class Meta(BaseModel): - - """ - Meta information for the MAPI Response - """ - - api_version: str = Field( - __version__, - description="a string containing the version of the Materials API " - "implementation, e.g. v0.9.5", - ) - - time_stamp: datetime = Field( - None, - description="a string containing the date and time at which the query was executed", - ) - - @validator("time_stamp", pre=True, always=True) - def default_timestamp(cls, v): - return v or datetime.utcnow() - - class Config: - extra = "allow" - - -class Error(BaseModel): - """ - Base Error model for Materials API - """ - - code: int = Field(..., description="The error code") - message: str = Field(..., description="The description of the error") - - @classmethod - def from_traceback(cls, traceback): - pass - - -class Response(GenericModel, Generic[DataT]): - """ - A Materials API Response - """ - - data: Optional[List[DataT]] = Field(None, description="List of returned data") - errors: Optional[List[Error]] = Field( - None, description="Any errors on processing this query" - ) - meta: Optional[Meta] = Field(None, description="Extra information for the query") - - @validator("errors", always=True) - def check_consistency(cls, v, values): - if v is not None and values["data"] is not None: - raise ValueError("must not provide both data and error") - if v is None and values.get("data") is None: - raise ValueError("must provide data or error") - return v - - @validator("meta", pre=True, always=True) - def default_meta(cls, v, values): - if v is None: - v = Meta().dict() - else: - if "total" not in v and values.get("data", None) is not None: - v["total"] = len(values["data"]) - return v diff --git a/src/mp_api/core/query_operator.py b/src/mp_api/core/query_operator.py deleted file mode 100644 index 7d710e75..00000000 --- a/src/mp_api/core/query_operator.py +++ /dev/null @@ -1,223 +0,0 @@ -import inspect -from typing import List, Dict, Optional, Tuple -from pydantic import BaseModel -from fastapi import Query, HTTPException -from monty.json import MSONable -from maggma.core import Store -from mp_api.core.utils import STORE_PARAMS, dynamic_import - - -class QueryOperator(MSONable): - """ - Base Query Operator class for defining powerful query language - in the Materials API - """ - - def query(self) -> STORE_PARAMS: - """ - The query function that does the work for this query operator - """ - raise NotImplementedError("Query operators must implement query") - - def meta(self, store: Store, query: Dict) -> Dict: - """ - Returns meta data to return with the Response - - Args: - store: the Maggma Store that the resource uses - query: the query being executed in this API call - """ - return {} - - def post_process(self, docs: List[Dict]) -> List[Dict]: - """ - An optional post-processing function for the data - """ - return docs - - def ensure_indexes(self) -> List[Tuple[str, bool]]: - """ - Returns tuples of keys other than the store index, - and whether they are unique, to be used to setup indexes - in MongoDB - """ - return [] - - def _keys_from_query(self) -> List[str]: - """ - Method to extract parameters from query method to be used as index keys - """ - return [key for key in inspect.signature(self.query).parameters] - - -class PaginationQuery(QueryOperator): - """Query opertators to provides Pagination in the Materials API""" - - def __init__( - self, default_skip: int = 0, default_limit: int = 10, max_limit: int = 10000 - ): - """ - Args: - default_skip: the default number of documents to skip - default_limit: the default number of documents to return - max_limit: max number of documents to return - """ - self.default_skip = default_skip - self.default_limit = default_limit - self.max_limit = max_limit - - def query( - skip: int = Query( - default_skip, description="Number of entries to skip in the search" - ), - limit: int = Query( - default_limit, - description="Max number of entries to return in a single query." - f" Limited to {max_limit}", - ), - ) -> STORE_PARAMS: - """ - Pagination parameters for the API Endpoint - """ - if limit <= 0 or limit > max_limit: - limit = max_limit - return {"skip": skip, "limit": limit} - - setattr(self, "query", query) - - def meta(self, store: Store, query: Dict) -> Dict: - """ - Metadata for the pagination params - - Args: - store: the Maggma Store that the resource uses - query: the query being executed in this API call - """ - - count = store.count(query) - return {"max_limit": self.max_limit, "total": count} - - -class SparseFieldsQuery(QueryOperator): - """ - Factory function to generate a dependency for sparse field sets in FastAPI - """ - - def __init__(self, model: BaseModel, default_fields: Optional[List[str]] = None): - """ - Args: - model: PyDantic Model that represents the underlying data source - default_fields: default fields to return in the API response if no fields are explicitly requested - """ - - self.model = model - - model_name = self.model.__name__ # type: ignore - model_fields = list(self.model.__fields__.keys()) - - self.default_fields = ( - model_fields if default_fields is None else list(default_fields) - ) - - def query( - fields: str = Query( - None, - description=f"Fields to project from {str(model_name)} as a list of comma seperated strings", - ), - all_fields: bool = Query(False, description="Include all fields."), - ) -> STORE_PARAMS: - """ - Pagination parameters for the API Endpoint - """ - - properties = ( - fields.split(",") if isinstance(fields, str) else self.default_fields - ) - if all_fields: - properties = model_fields - - return {"properties": properties} - - setattr(self, "query", query) - - def meta(self, store: Store, query: Dict) -> Dict: - """ - Returns metadata for the Sparse field set - - Args: - store: the Maggma Store that the resource uses - query: the query being executed in this API call - """ - return {"default_fields": self.default_fields} - - def as_dict(self) -> Dict: - """ - Special as_dict implemented to convert pydantic models into strings - """ - - d = super().as_dict() # Ensures sub-classes serialize correctly - d["model"] = f"{self.model.__module__}.{self.model.__name__}" # type: ignore - return d - - @classmethod - def from_dict(cls, d): - - model = d.get("model") - if isinstance(model, str): - module_path = ".".join(model.split(".")[:-1]) - class_name = model.split(".")[-1] - model = dynamic_import(module_path, class_name) - - assert issubclass( - model, BaseModel - ), "The resource model has to be a PyDantic Model" - d["model"] = model - - return cls(**d) - - -class VersionQuery(QueryOperator): - """ - Method to generate a query on a specific collection version - """ - - def query( - self, - version: Optional[str] = Query( - None, description="Database version to query on formatted as YYYY.MM.DD", - ), - ) -> STORE_PARAMS: - - crit = {} - - if version: - crit.update({"version": version}) - - return {"criteria": crit} - - -class SortQuery(QueryOperator): - """ - Method to generate the sorting portion of a query - """ - - def query( - self, - field: Optional[str] = Query(None, description="Field to sort with"), - ascending: Optional[bool] = Query( - None, description="Whether the sorting should be ascending", - ), - ) -> STORE_PARAMS: - - sort = {} - - if field and ascending is not None: - sort.update({field: 1 if ascending else -1}) - - elif field or ascending is not None: - raise HTTPException( - status_code=404, - detail="Must specify both a field and order for sorting.", - ) - - return {"sort": sort} diff --git a/src/mp_api/core/resource.py b/src/mp_api/core/resource.py deleted file mode 100644 index 899ee636..00000000 --- a/src/mp_api/core/resource.py +++ /dev/null @@ -1,507 +0,0 @@ -import os -from abc import ABC, abstractmethod -from typing import List, Dict, Union, Optional, Callable -from starlette.responses import RedirectResponse -from pydantic import BaseModel -from monty.json import MSONable -from fastapi import FastAPI, APIRouter, Path, HTTPException, Depends, Query, Request -from inspect import signature - -from maggma.core import Store - -from mp_api.core.models import Response -from mp_api.core.utils import ( - STORE_PARAMS, - merge_queries, - attach_signature, - dynamic_import, - api_sanitize, -) - -from mp_api.core.settings import MAPISettings - -from mp_api.core.query_operator import ( - QueryOperator, - PaginationQuery, - SparseFieldsQuery, - VersionQuery, -) - - -class Resource(MSONable, ABC): - """ - Base class for a REST Compatible Resource - """ - - def __init__( - self, - store: Store, - model: Union[BaseModel, str] = None, - tags: Optional[List[str]] = None, - query_operators: Optional[List[QueryOperator]] = None, - include_in_schema: Optional[bool] = True, - ): - """ - Args: - store: The Maggma Store to get data from - model: the pydantic model to apply to the documents from the Store - This can be a string with a full python path to a model or - an actual pydantic Model if this is being instantiated in python - code. Serializing this via Monty will auto-convert the pydantic model - into a python path string - tags: list of tags for the Endpoint - query_operators: operators for the query language - include_in_schema: Whether the endpoint should be shown in the documented schema. - """ - self.store = store - self.tags = tags or [] - self.query_operators = query_operators - self.include_in_schema = include_in_schema - - if isinstance(model, str): - module_path = ".".join(model.split(".")[:-1]) - class_name = model.split(".")[-1] - class_model = dynamic_import(module_path, class_name) - assert issubclass( - class_model, BaseModel - ), "The resource model has to be a PyDantic Model" - self.model = api_sanitize(class_model, allow_dict_msonable=True) - elif isinstance(model, type) and issubclass(model, (BaseModel, MSONable)): - self.model = api_sanitize(model, allow_dict_msonable=True) - else: - raise ValueError("The resource model has to be a PyDantic Model") - - self.router = APIRouter() - self.response_model = Response[self.model] # type: ignore - self.setup_redirect() - self.prepare_endpoint() - - @abstractmethod - def prepare_endpoint(self): - """ - Internal method to prepare the endpoint by setting up default handlers - for routes - """ - pass - - def setup_redirect(self): - @self.router.get("", include_in_schema=False) - def redirect_unslashes(): - """ - Redirects unforward slashed url to resource - url with the forward slash - """ - - url = self.router.url_path_for("/") - return RedirectResponse(url=url, status_code=301) - - def setup_indexes(self): - """ - Internal method to ensure indexes in MongoDB - """ - - self.store.connect() - self.store.ensure_index(self.store.key, unique=True) - if self.query_operators is not None: - for query_operator in self.query_operators: - keys = query_operator.ensure_indexes() - if keys: - for (key, unique) in keys: - self.store.ensure_index(key, unique=unique) - - def run(self): # pragma: no cover - """ - Runs the Endpoint cluster locally - This is intended for testing not production - """ - import uvicorn - - app = FastAPI() - app.include_router(self.router, prefix="") - uvicorn.run(app) - - def as_dict(self) -> Dict: - """ - Special as_dict implemented to convert pydantic models into strings - """ - - d = super().as_dict() # Ensures sub-classes serialize correctly - d["model"] = f"{self.model.__module__}.{self.model.__name__}" - return d - - -class GetResource(Resource): - """ - Implements a REST Compatible Resource as a GET URL endpoint - This class provides a number of convenience features - including full pagination, field projection, and the - MAPI query lanaugage - """ - - def __init__( - self, - store: Store, - model: Union[BaseModel, str], - tags: Optional[List[str]] = None, - query_operators: Optional[List[QueryOperator]] = None, - key_fields: Optional[List[str]] = None, - custom_endpoint_funcs: Optional[List[Callable]] = None, - enable_get_by_key: Optional[bool] = True, - enable_default_search: Optional[bool] = True, - include_in_schema: Optional[bool] = True, - ): - """ - Args: - store: The Maggma Store to get data from - model: the pydantic model to apply to the documents from the Store - This can be a string with a full python path to a model or - an actual pydantic Model if this is being instantiated in python - code. Serializing this via Monty will auto-convert the pydantic model - into a python path string - tags: list of tags for the Endpoint - query_operators: operators for the query language - key_fields: List of fields to always project. Default uses SparseFieldsQuery - to allow user's to define these on-the-fly. - custom_endpoint_funcs: Custom endpoint preparation functions to be used - enable_get_by_key: Enable default key route for endpoint. - enable_default_search: Enable default endpoint search behavior. - include_in_schema: Whether the endpoint should be shown in the documented schema. - """ - - self.key_fields = key_fields - self.versioned = False - self.cep = custom_endpoint_funcs - self.enable_get_by_key = enable_get_by_key - self.enable_default_search = enable_default_search - - super().__init__( - store, - model=model, - tags=tags, - query_operators=query_operators, - include_in_schema=include_in_schema, - ) - - self.query_operators = ( - query_operators - if query_operators is not None - else [ - PaginationQuery(), - SparseFieldsQuery( - self.model, - default_fields=[self.store.key, self.store.last_updated_field], - ), - ] - ) # type: list - - if any( - isinstance(qop_entry, VersionQuery) for qop_entry in self.query_operators - ): - self.versioned = True - - self.prepare_endpoint() - - def prepare_endpoint(self): - """ - Internal method to prepare the endpoint by setting up default handlers - for routes - """ - - if self.cep is not None: - for func in self.cep: - func(self) - - if self.enable_get_by_key: - self.build_get_by_key() - - if self.enable_default_search: - self.set_dynamic_model_search() - - def build_get_by_key(self): - key_name = self.store.key - model_name = self.model.__name__ - - if self.key_fields is None: - field_input = SparseFieldsQuery( - self.model, [self.store.key, self.store.last_updated_field] - ).query - else: - - def field_input(): - return {"properties": self.key_fields} - - if not self.versioned: - - async def get_by_key( - key: str = Path( - ..., - alias=key_name, - title=f"The {key_name} of the {model_name} to get", - ), - fields: STORE_PARAMS = Depends(field_input), - ): - f""" - Get's a document by the primary key in the store - - Args: - {key_name}: the id of a single {model_name} - - Returns: - a single {model_name} document - """ - self.store.connect() - - crit = {self.store.key: key} - - if model_name == "TaskDoc": - crit.update({"sbxn": "core"}) - - item = [ - self.store.query_one(criteria=crit, properties=fields["properties"]) - ] - - if item == [None]: - raise HTTPException( - status_code=404, - detail=f"Item with {self.store.key} = {key} not found", - ) - - for operator in self.query_operators: - item = operator.post_process(item) - - response = {"data": item} - - return response - - self.router.get( - f"/{{{key_name}}}/", - response_description=f"Get an {model_name} by {key_name}", - response_model=self.response_model, - response_model_exclude_unset=True, - tags=self.tags, - include_in_schema=self.include_in_schema, - )(get_by_key) - - else: - - async def get_by_key_versioned( - key: str = Path( - ..., - alias=key_name, - title=f"The {key_name} of the {model_name} to get", - ), - fields: STORE_PARAMS = Depends(field_input), - version: str = Query( - None, - description="Database version to query on formatted as YYYY.MM.DD", - ), - ): - f""" - Get's a document by the primary key in the store - - Args: - {key_name}: the id of a single {model_name} - - Returns: - a single {model_name} document - """ - - if version is not None: - version = version.replace(".", "_") - else: - version = os.environ.get("DB_VERSION", MAPISettings().db_version) - - prefix = self.store.collection_name.split("_")[0] - self.store.collection_name = f"{prefix}_{version}" - - self.store.connect(force_reset=True) - - crit = {self.store.key: key} - - if model_name == "TaskDoc": - crit.update({"sbxn": "core"}) - - item = [ - self.store.query_one(criteria=crit, properties=fields["properties"]) - ] - - if item == [None]: - raise HTTPException( - status_code=404, - detail=f"Item with {self.store.key} = {key} not found", - ) - - for operator in self.query_operators: - item = operator.post_process(item) - - response = {"data": item} - - return response - - self.router.get( - f"/{{{key_name}}}/", - response_description=f"Get an {model_name} by {key_name}", - response_model=self.response_model, - response_model_exclude_unset=True, - tags=self.tags, - include_in_schema=self.include_in_schema, - )(get_by_key_versioned) - - def set_dynamic_model_search(self): - - model_name = self.model.__name__ - - async def search(**queries: STORE_PARAMS): - - request: Request = queries.pop("request") # type: ignore - - query: STORE_PARAMS = merge_queries(list(queries.values())) - - query_params = [ - entry - for _, i in enumerate(self.query_operators) - for entry in signature(i.query).parameters - ] - - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] - if any(overlap): - raise HTTPException( - status_code=404, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), - ) - - if self.versioned: - if query["criteria"].get("version", None) is not None: - version = query["criteria"]["version"].replace(".", "_") - query["criteria"].pop("version") - - else: - version = os.environ.get("DB_VERSION", MAPISettings().db_version) - - prefix = self.store.collection_name.split("_")[0] - self.store.collection_name = f"{prefix}_{version}" - - self.store.connect(force_reset=True) - - if model_name == "TaskDoc": - query["criteria"].update({"sbxn": "core"}) - - data = list(self.store.query(**query)) # type: ignore - operator_metas = [ - operator.meta(self.store, query.get("criteria", {})) - for operator in self.query_operators - ] - meta = {k: v for m in operator_metas for k, v in m.items()} - - for operator in self.query_operators: - data = operator.post_process(data) - - response = {"data": data, "meta": meta} - - return response - - ann = {f"dep{i}": STORE_PARAMS for i, _ in enumerate(self.query_operators)} - ann.update({"request": Request}) - attach_signature( - search, - annotations=ann, - defaults={ - f"dep{i}": Depends(dep.query) - for i, dep in enumerate(self.query_operators) - }, - ) - - self.router.get( - "/", - tags=self.tags, - summary=f"Get {model_name} documents", - response_model=self.response_model, - response_description=f"Search for a {model_name}", - response_model_exclude_unset=True, - include_in_schema=self.include_in_schema, - )(search) - - -class ConsumerPostResource(Resource): - """ - Implements a REST Compatible Resource as a POST URL endpoint - for private consumer data. - """ - - def prepare_endpoint(self): - """ - Internal method to prepare the endpoint by setting up default handlers - for routes - """ - - model_name = self.model.__name__ - - async def search(**queries: STORE_PARAMS): - - request: Request = queries.pop("request") # type: ignore - - query: STORE_PARAMS = merge_queries(list(queries.values())) - - query_params = [ - entry - for _, i in enumerate(self.query_operators) # type: ignore - for entry in signature(i.query).parameters - ] - - overlap = [ - key for key in request.query_params.keys() if key not in query_params - ] - if any(overlap): - raise HTTPException( - status_code=404, - detail="Request contains query parameters which cannot be used: {}".format( - ", ".join(overlap) - ), - ) - - self.store.connect(force_reset=True) - - operator_metas = [ - operator.meta(self.store, query.get("criteria", {})) - for operator in self.query_operators # type: ignore - ] - meta = {k: v for m in operator_metas for k, v in m.items()} - - try: - docs = self.store.update(docs=query["criteria"]) # type: ignore - except Exception: - raise HTTPException( - status_code=404, detail="Problem when trying to post data.", - ) - - for operator in self.query_operators: # type: ignore - data = operator.post_process(docs) - - response = {"data": data, "meta": meta} - - return response - - ann = {f"dep{i}": STORE_PARAMS for i, _ in enumerate(self.query_operators)} - ann.update({"request": Request}) - attach_signature( - search, - annotations=ann, - defaults={ - f"dep{i}": Depends(dep.query) - for i, dep in enumerate(self.query_operators) - }, - ) - - self.router.post( - "/", - tags=self.tags, - summary=f"Post {model_name} data", - response_model=self.response_model, - response_description=f"Posted data {model_name} data", - response_model_exclude_unset=True, - include_in_schema=self.include_in_schema, - )(search) diff --git a/src/mp_api/core/settings.py b/src/mp_api/core/settings.py index 865b261d..729354cb 100644 --- a/src/mp_api/core/settings.py +++ b/src/mp_api/core/settings.py @@ -9,7 +9,9 @@ class MAPISettings(BaseSettings): python module """ - app_path: str = Field("~/mapi.json", description="Path for the default MAPI JSON definition") + app_path: str = Field( + "~/mapi.json", description="Path for the default MAPI JSON definition" + ) debug: bool = Field(False, description="Turns on debug mode for MAPI") diff --git a/src/mp_api/core/utils.py b/src/mp_api/core/utils.py deleted file mode 100644 index c45d3e82..00000000 --- a/src/mp_api/core/utils.py +++ /dev/null @@ -1,162 +0,0 @@ -import inspect -from typing import List, Dict, Callable, Any, Optional, Type -from typing_extensions import Literal -from importlib import import_module - -from monty.json import MSONable -from pydantic.utils import lenient_issubclass -from pydantic.schema import get_flat_models_from_model -from pydantic import BaseModel - -QUERY_PARAMS = ["criteria", "properties", "sort", "skip", "limit"] -STORE_PARAMS = Dict[Literal["criteria", "properties", "sort", "skip", "limit"], Any] - - -def dynamic_import(abs_module_path: str, class_name: str): - """ - Dynamic class importer from: https://www.bnmetrics.com/blog/dynamic-import-in-python3 - """ - module_object = import_module(abs_module_path) - target_class = getattr(module_object, class_name) - return target_class - - -def merge_queries(queries: List[STORE_PARAMS]) -> STORE_PARAMS: - - criteria: STORE_PARAMS = {} - properties: List[str] = [] - - for sub_query in queries: - if "criteria" in sub_query: - criteria.update(sub_query["criteria"]) - if "properties" in sub_query: - properties.extend(sub_query["properties"]) - - remainder = { - k: v - for query in queries - for k, v in query.items() - if k not in ["criteria", "properties"] - } - - return { - "criteria": criteria, - "properties": properties if len(properties) > 0 else None, - **remainder, - } - - -def attach_signature(function: Callable, defaults: Dict, annotations: Dict): - """ - Attaches signature for defaults and annotations for parameters to function - - Args: - function: callable function to attach the signature to - defaults: dictionary of parameters -> default values - annotations: dictionary of type annoations for the parameters - """ - - required_params = [ - inspect.Parameter( - param, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=defaults.get(param, None), - annotation=annotations.get(param, None), - ) - for param in annotations.keys() - if param not in defaults.keys() - ] - - optional_params = [ - inspect.Parameter( - param, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=defaults.get(param, None), - annotation=annotations.get(param, None), - ) - for param in defaults.keys() - ] - - setattr( - function, "__signature__", inspect.Signature(required_params + optional_params) - ) - - -def api_sanitize( - pydantic_model: Type[BaseModel], - fields_to_leave: Optional[List[str]] = None, - allow_dict_msonable=False, -): - """ - Function to clean up pydantic models for the API by: - 1.) Making fields optional - 2.) Allowing dictionaries in-place of the objects for MSONable quantities - - WARNING: This works in place, so it mutates the model and all sub-models - - Args: - fields_to_leave: list of strings for model fields as "model__name__.field" - """ - - models = [ - model - for model in get_flat_models_from_model(pydantic_model) - if lenient_issubclass(model, BaseModel) - ] - - fields_to_leave = fields_to_leave or [] - fields_tuples = [f.split(".") for f in fields_to_leave] - assert all(len(f) == 2 for f in fields_tuples) - - for model in models: - model_fields_to_leave = {f[1] for f in fields_tuples if model.__name__ == f[0]} - for name, field in model.__fields__.items(): # type: ignore - field_type = field.type_ - - if name not in model_fields_to_leave: - field.required = False - field.field_info.default = None - - if ( - field_type is not None - and lenient_issubclass(field_type, MSONable) - and allow_dict_msonable - ): - field.type_ = allow_msonable_dict(field_type) - field.populate_validators() - - return pydantic_model - - -def allow_msonable_dict(monty_cls: Type[MSONable]): - """ - Patch Monty to allow for dict values for MSONable - """ - - def validate_monty(cls, v): - """ - Stub validator for MSONable as a dictionary only - """ - if isinstance(v, cls): - return v - elif isinstance(v, dict): - # Just validate the simple Monty Dict Model - errors = [] - if v.get("@module", "") != monty_cls.__module__: - errors.append("@module") - - if v.get("@class", "") != monty_cls.__name__: - errors.append("@class") - - if len(errors) > 0: - raise ValueError( - "Missing Monty seriailzation fields in dictionary: {errors}" - ) - - return v - else: - raise ValueError(f"Must provide {cls.__name__} or MSONable dictionary") - - setattr(monty_cls, "validate_monty", classmethod(validate_monty)) - - return monty_cls diff --git a/src/mp_api/matproj.py b/src/mp_api/matproj.py index d82e977e..66ea08ec 100644 --- a/src/mp_api/matproj.py +++ b/src/mp_api/matproj.py @@ -62,7 +62,9 @@ def __init__( self.api_key = api_key self.endpoint = endpoint self.version = version - self.session = BaseRester._create_session(api_key=api_key, include_user_agent=include_user_agent) + self.session = BaseRester._create_session( + api_key=api_key, include_user_agent=include_user_agent + ) self._all_resters = [] @@ -99,7 +101,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): # eventually be retired. # @deprecated(self.materials.get_structure_by_material_id, _DEPRECATION_WARNING) - def get_structure_by_material_id(self, material_id, final=True, conventional_unit_cell=False) -> Structure: + def get_structure_by_material_id( + self, material_id, final=True, conventional_unit_cell=False + ) -> Structure: """ Get a Structure corresponding to a material_id. @@ -149,11 +153,15 @@ def get_materials_id_from_task_id(self, task_id, version=None): Returns: materials_id (str) """ - docs = self.materials.search(task_ids=[task_id], fields=["material_id"], version=version) + docs = self.materials.search( + task_ids=[task_id], fields=["material_id"], version=version + ) if len(docs) == 1: return docs[0].material_id elif len(docs) > 1: - raise ValueError(f"Multiple documents return for {task_id}, this should not happen, please report it!") + raise ValueError( + f"Multiple documents return for {task_id}, this should not happen, please report it!" + ) else: warnings.warn( f"No material found containing task {task_id}. Please report it if you suspect a task has gone missing." @@ -171,7 +179,12 @@ def get_materials_ids(self, chemsys_formula): Returns: ([str]) List of all materials ids. """ - return sorted(doc.task_id for doc in self.materials.search_material_docs(chemsys_formula=chemsys_formula)) + return sorted( + doc.task_id + for doc in self.materials.search_material_docs( + chemsys_formula=chemsys_formula + ) + ) def get_structures(self, chemsys_formula_id, energy_above_hull_cutoff=0): """ @@ -243,7 +256,9 @@ def get_entries( """ raise NotImplementedError - def get_pourbaix_entries(self, chemsys, solid_compat=MaterialsProjectCompatibility()): + def get_pourbaix_entries( + self, chemsys, solid_compat=MaterialsProjectCompatibility() + ): """ A helper function to get all entries necessary to generate a pourbaix diagram from the rest interface. @@ -257,7 +272,12 @@ def get_pourbaix_entries(self, chemsys, solid_compat=MaterialsProjectCompatibili raise NotImplementedError def get_entry_by_material_id( - self, material_id, compatible_only=True, inc_structure=None, property_data=None, conventional_unit_cell=False, + self, + material_id, + compatible_only=True, + inc_structure=None, + property_data=None, + conventional_unit_cell=False, ): """ Get a ComputedEntry corresponding to a material_id. @@ -341,7 +361,12 @@ def get_phonon_ddb_by_material_id(self, material_id): raise NotImplementedError def get_entries_in_chemsys( - self, elements, compatible_only=True, inc_structure=None, property_data=None, conventional_unit_cell=False, + self, + elements, + compatible_only=True, + inc_structure=None, + property_data=None, + conventional_unit_cell=False, ): """ Helper method to get a list of ComputedEntries in a chemical system. @@ -406,7 +431,12 @@ def get_exp_entry(self, formula): raise NotImplementedError def query( - self, criteria, properties, chunk_size=500, max_tries_per_chunk=5, mp_decode=True, + self, + criteria, + properties, + chunk_size=500, + max_tries_per_chunk=5, + mp_decode=True, ): r""" @@ -768,7 +798,12 @@ def get_gb_data( raise NotImplementedError def get_interface_reactions( - self, reactant1, reactant2, open_el=None, relative_mu=None, use_hull_energy=False, + self, + reactant1, + reactant2, + open_el=None, + relative_mu=None, + use_hull_energy=False, ): """ Gets critical reactions between two reactants. diff --git a/src/mp_api/routes/_consumer/query_operator.py b/src/mp_api/routes/_consumer/query_operator.py index 3fa558f9..d2f4f065 100644 --- a/src/mp_api/routes/_consumer/query_operator.py +++ b/src/mp_api/routes/_consumer/query_operator.py @@ -1,7 +1,7 @@ from typing import Dict from fastapi import Query, Body -from mp_api.core.utils import STORE_PARAMS -from mp_api.core.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS +from maggma.api.query_operator import QueryOperator class UserSettingsPostQuery(QueryOperator): diff --git a/src/mp_api/routes/_consumer/resources.py b/src/mp_api/routes/_consumer/resources.py index e5401eaf..a8b72e6a 100644 --- a/src/mp_api/routes/_consumer/resources.py +++ b/src/mp_api/routes/_consumer/resources.py @@ -1,30 +1,20 @@ -from mp_api.core.resource import ConsumerPostResource, GetResource +from maggma.api.resource import SubmissionResource +from mp_api.routes._consumer.query_operator import ( + UserSettingsPostQuery, + UserSettingsGetQuery, +) from mp_api.routes._consumer.models import UserSettingsDoc -from mp_api.routes._consumer.query_operator import UserSettingsPostQuery, UserSettingsGetQuery -def set_settings_resource(consumer_settings_store): - resource = ConsumerPostResource( +def settings_resource(consumer_settings_store): + resource = SubmissionResource( consumer_settings_store, UserSettingsDoc, - query_operators=[UserSettingsPostQuery()], - tags=["Consumer"], - include_in_schema=False, - ) - - return resource - - -def get_settings_resource(consumer_settings_store): - resource = GetResource( - consumer_settings_store, - UserSettingsDoc, - query_operators=[UserSettingsGetQuery()], - tags=["Consumer"], - key_fields=["consumer_id", "settings"], - enable_get_by_key=True, + post_query_operators=[UserSettingsPostQuery()], + get_query_operators=[UserSettingsGetQuery()], enable_default_search=False, include_in_schema=False, ) return resource + diff --git a/src/mp_api/routes/charge_density/models.py b/src/mp_api/routes/charge_density/models.py index 6d94b504..c20bb268 100644 --- a/src/mp_api/routes/charge_density/models.py +++ b/src/mp_api/routes/charge_density/models.py @@ -1,6 +1,5 @@ -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field from datetime import datetime -from monty.json import MontyDecoder from pymatgen.io.vasp import Chgcar @@ -9,13 +8,10 @@ class ChgcarDataDoc(BaseModel): Electron charge density for selected materials. """ - fs_id: str = Field( - None, description="Unique object ID for the charge density data." - ) + fs_id: str = Field(None, description="Unique object ID for the charge density data.") last_updated: datetime = Field( - None, - description="Timestamp for the most recent update to the charge density data.", + None, description="Timestamp for the most recent update to the charge density data.", ) task_id: str = Field( diff --git a/src/mp_api/routes/charge_density/resources.py b/src/mp_api/routes/charge_density/resources.py index 9d6cc267..6c0d5cf2 100644 --- a/src/mp_api/routes/charge_density/resources.py +++ b/src/mp_api/routes/charge_density/resources.py @@ -1,78 +1,15 @@ -from fastapi.param_functions import Path -from mp_api.core.resource import GetResource -from mp_api.core.models import Response +from maggma.api.resource import ReadOnlyResource from mp_api.routes.charge_density.models import ChgcarDataDoc -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery, SortQuery -from mp_api.core.utils import STORE_PARAMS -from fastapi import HTTPException, Depends +from maggma.api.query_operator import SparseFieldsQuery def charge_density_resource(s3_store): - def custom_charge_density_endpoint_prep(self): - - self.s3 = s3_store - model = ChgcarDataDoc - model_name = model.__name__ - key_name = "task_id" - - field_input = SparseFieldsQuery( - model, [key_name, self.s3.last_updated_field] - ).query - - async def get_chgcar_data( - material_id: str = Path( - ..., - alias=key_name, - title=f"The Material ID ({key_name}) associated with the {model_name}", - ), - fields: STORE_PARAMS = Depends(field_input), - ): - f""" - Get's a document by the primary key in the store - - Args: - material_id: The Materials Project ID ({key_name}) of a single {model_name} - - Returns: - a single {model_name} document - """ - - self.s3.connect() - - item = self.s3.query_one( - {key_name: material_id}, properties=fields["properties"] - ) - - if item is None: - raise HTTPException( - status_code=404, - detail=f"Item with {key_name} = {material_id} not found", - ) - else: - return {"data": [item]} - - self.router.get( - f"/{{{key_name}}}/", - response_description=f"Get an {model_name} by {key_name}", - response_model=Response[model], - response_model_exclude_unset=True, - tags=self.tags, - )(get_chgcar_data) - - resource = GetResource( + resource = ReadOnlyResource( s3_store, ChgcarDataDoc, - query_operators=[ - SortQuery(), - PaginationQuery(), - SparseFieldsQuery( - ChgcarDataDoc, default_fields=["task_id", "last_updated"] - ), - ], tags=["Charge Density"], - custom_endpoint_funcs=[custom_charge_density_endpoint_prep], enable_default_search=False, - enable_get_by_key=False, + enable_get_by_key=True, ) return resource diff --git a/src/mp_api/routes/dielectric/query_operators.py b/src/mp_api/routes/dielectric/query_operators.py index c3f9f466..e2728d3c 100644 --- a/src/mp_api/routes/dielectric/query_operators.py +++ b/src/mp_api/routes/dielectric/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict @@ -12,30 +13,14 @@ class DielectricQuery(QueryOperator): def query( self, - e_total_max: Optional[float] = Query( - None, description="Maximum value for the total dielectric constant.", - ), - e_total_min: Optional[float] = Query( - None, description="Minimum value for the total dielectric constant.", - ), - e_ionic_max: Optional[float] = Query( - None, description="Maximum value for the ionic dielectric constant.", - ), - e_ionic_min: Optional[float] = Query( - None, description="Minimum value for the ionic dielectric constant.", - ), - e_static_max: Optional[float] = Query( - None, description="Maximum value for the static dielectric constant.", - ), - e_static_min: Optional[float] = Query( - None, description="Minimum value for the static dielectric constant.", - ), - n_max: Optional[float] = Query( - None, description="Maximum value for the refractive index.", - ), - n_min: Optional[float] = Query( - None, description="Minimum value for the refractive index.", - ), + e_total_max: Optional[float] = Query(None, description="Maximum value for the total dielectric constant.",), + e_total_min: Optional[float] = Query(None, description="Minimum value for the total dielectric constant.",), + e_ionic_max: Optional[float] = Query(None, description="Maximum value for the ionic dielectric constant.",), + e_ionic_min: Optional[float] = Query(None, description="Minimum value for the ionic dielectric constant.",), + e_static_max: Optional[float] = Query(None, description="Maximum value for the static dielectric constant.",), + e_static_min: Optional[float] = Query(None, description="Minimum value for the static dielectric constant.",), + n_max: Optional[float] = Query(None, description="Maximum value for the refractive index.",), + n_min: Optional[float] = Query(None, description="Minimum value for the refractive index.",), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict diff --git a/src/mp_api/routes/dielectric/resources.py b/src/mp_api/routes/dielectric/resources.py index b3d584aa..6ad11b9f 100644 --- a/src/mp_api/routes/dielectric/resources.py +++ b/src/mp_api/routes/dielectric/resources.py @@ -1,21 +1,19 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.dielectric.models import DielectricDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.dielectric.query_operators import DielectricQuery def dielectric_resource(dielectric_store): - resource = GetResource( + resource = ReadOnlyResource( dielectric_store, DielectricDoc, query_operators=[ DielectricQuery(), SortQuery(), PaginationQuery(), - SparseFieldsQuery( - DielectricDoc, default_fields=["task_id", "last_updated"] - ), + SparseFieldsQuery(DielectricDoc, default_fields=["task_id", "last_updated"]), ], tags=["Dielectric"], ) diff --git a/src/mp_api/routes/dois/resources.py b/src/mp_api/routes/dois/resources.py index b7445298..18992851 100644 --- a/src/mp_api/routes/dois/resources.py +++ b/src/mp_api/routes/dois/resources.py @@ -1,17 +1,14 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.dois.models import DOIDoc -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SparseFieldsQuery def dois_resource(dois_store): - resource = GetResource( + resource = ReadOnlyResource( dois_store, DOIDoc, - query_operators=[ - PaginationQuery(), - SparseFieldsQuery(DOIDoc, default_fields=["task_id", "doi"]), - ], + query_operators=[PaginationQuery(), SparseFieldsQuery(DOIDoc, default_fields=["task_id", "doi"]),], tags=["DOIs"], enable_default_search=False, ) diff --git a/src/mp_api/routes/elasticity/query_operators.py b/src/mp_api/routes/elasticity/query_operators.py index 21a98ae7..72fd04ab 100644 --- a/src/mp_api/routes/elasticity/query_operators.py +++ b/src/mp_api/routes/elasticity/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict @@ -13,28 +14,22 @@ class BulkModulusQuery(QueryOperator): def query( self, k_voigt_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt average of the bulk modulus in GPa.", + None, description="Maximum value for the Voigt average of the bulk modulus in GPa.", ), k_voigt_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt average of the bulk modulus in GPa.", + None, description="Minimum value for the Voigt average of the bulk modulus in GPa.", ), k_reuss_max: Optional[float] = Query( - None, - description="Maximum value for the Reuss average of the bulk modulus in GPa.", + None, description="Maximum value for the Reuss average of the bulk modulus in GPa.", ), k_reuss_min: Optional[float] = Query( - None, - description="Minimum value for the Reuss average of the bulk modulus in GPa.", + None, description="Minimum value for the Reuss average of the bulk modulus in GPa.", ), k_vrh_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", + None, description="Maximum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", ), k_vrh_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", + None, description="Minimum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", ), ) -> STORE_PARAMS: @@ -64,28 +59,22 @@ class ShearModulusQuery(QueryOperator): def query( self, g_voigt_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt average of the shear modulus in GPa.", + None, description="Maximum value for the Voigt average of the shear modulus in GPa.", ), g_voigt_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt average of the shear modulus in GPa.", + None, description="Minimum value for the Voigt average of the shear modulus in GPa.", ), g_reuss_max: Optional[float] = Query( - None, - description="Maximum value for the Reuss average of the shear modulus in GPa.", + None, description="Maximum value for the Reuss average of the shear modulus in GPa.", ), g_reuss_min: Optional[float] = Query( - None, - description="Minimum value for the Reuss average of the shear modulus in GPa.", + None, description="Minimum value for the Reuss average of the shear modulus in GPa.", ), g_vrh_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", + None, description="Maximum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", ), g_vrh_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", + None, description="Minimum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", ), ) -> STORE_PARAMS: @@ -115,31 +104,16 @@ class PoissonQuery(QueryOperator): def query( self, - elastic_anisotropy_max: Optional[float] = Query( - None, - description="Maximum value for the elastic anisotropy.", - ), - elastic_anisotropy_min: Optional[float] = Query( - None, - description="Maximum value for the elastic anisotropy.", - ), - poisson_max: Optional[float] = Query( - None, - description="Maximum value for Poisson's ratio.", - ), - poisson_min: Optional[float] = Query( - None, - description="Minimum value for Poisson's ratio.", - ), + elastic_anisotropy_max: Optional[float] = Query(None, description="Maximum value for the elastic anisotropy.",), + elastic_anisotropy_min: Optional[float] = Query(None, description="Maximum value for the elastic anisotropy.",), + poisson_max: Optional[float] = Query(None, description="Maximum value for Poisson's ratio.",), + poisson_min: Optional[float] = Query(None, description="Minimum value for Poisson's ratio.",), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict d = { - "elasticity.universal_anisotropy": [ - elastic_anisotropy_min, - elastic_anisotropy_max, - ], + "elasticity.universal_anisotropy": [elastic_anisotropy_min, elastic_anisotropy_max,], "elasticity.homogeneous_poisson": [poisson_min, poisson_max], } @@ -159,11 +133,7 @@ class ChemsysQuery(QueryOperator): """ def query( - self, - chemsys: Optional[str] = Query( - None, - description="Dash-delimited list of elements in the material.", - ), + self, chemsys: Optional[str] = Query(None, description="Dash-delimited list of elements in the material.",), ): crit = {} # type: dict diff --git a/src/mp_api/routes/elasticity/resources.py b/src/mp_api/routes/elasticity/resources.py index c29b59ef..8e4ffd4b 100644 --- a/src/mp_api/routes/elasticity/resources.py +++ b/src/mp_api/routes/elasticity/resources.py @@ -1,7 +1,7 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.elasticity.models import ElasticityDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.elasticity.query_operators import ( ChemsysQuery, BulkModulusQuery, @@ -11,7 +11,7 @@ def elasticity_resource(elasticity_store): - resource = GetResource( + resource = ReadOnlyResource( elasticity_store, ElasticityDoc, query_operators=[ @@ -21,9 +21,7 @@ def elasticity_resource(elasticity_store): PoissonQuery(), SortQuery(), PaginationQuery(), - SparseFieldsQuery( - ElasticityDoc, default_fields=["task_id", "pretty_formula"], - ), + SparseFieldsQuery(ElasticityDoc, default_fields=["task_id", "pretty_formula"],), ], tags=["Elasticity"], ) diff --git a/src/mp_api/routes/electrodes/query_operators.py b/src/mp_api/routes/electrodes/query_operators.py index 00a43457..16539309 100644 --- a/src/mp_api/routes/electrodes/query_operators.py +++ b/src/mp_api/routes/electrodes/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from mp_api.routes.materials.utils import formula_to_criteria from pymatgen.core.periodic_table import Element @@ -15,8 +16,7 @@ class ElectrodeFormulaQuery(QueryOperator): def query( self, formula: Optional[str] = Query( - None, - description="Query by formula including anonymized formula or by including wild cards", + None, description="Query by formula including anonymized formula or by including wild cards", ), ) -> STORE_PARAMS: @@ -45,36 +45,28 @@ class VoltageStepQuery(QueryOperator): def query( self, delta_volume_max: Optional[float] = Query( - None, - description="Maximum value for the max volume change in percent for a particular voltage step.", + None, description="Maximum value for the max volume change in percent for a particular voltage step.", ), delta_volume_min: Optional[float] = Query( - None, - description="Minimum value for the max volume change in percent for a particular voltage step.", + None, description="Minimum value for the max volume change in percent for a particular voltage step.", ), average_voltage_max: Optional[float] = Query( - None, - description="Maximum value for the average voltage for a particular voltage step in V.", + None, description="Maximum value for the average voltage for a particular voltage step in V.", ), average_voltage_min: Optional[float] = Query( - None, - description="Minimum value for the average voltage for a particular voltage step in V.", + None, description="Minimum value for the average voltage for a particular voltage step in V.", ), max_voltage_max: Optional[float] = Query( - None, - description="Maximum value for the maximum voltage for a particular voltage step in V.", + None, description="Maximum value for the maximum voltage for a particular voltage step in V.", ), max_voltage_min: Optional[float] = Query( - None, - description="Minimum value for the maximum voltage for a particular voltage step in V.", + None, description="Minimum value for the maximum voltage for a particular voltage step in V.", ), min_voltage_max: Optional[float] = Query( - None, - description="Maximum value for the minimum voltage for a particular voltage step in V.", + None, description="Maximum value for the minimum voltage for a particular voltage step in V.", ), min_voltage_min: Optional[float] = Query( - None, - description="Minimum value for the minimum voltage for a particular voltage step in V.", + None, description="Minimum value for the minimum voltage for a particular voltage step in V.", ), capacity_grav_max: Optional[float] = Query( None, description="Maximum value for the gravimetric capacity in maH/g.", @@ -89,36 +81,28 @@ def query( None, description="Minimum value for the volumetric capacity in maH/cc.", ), energy_grav_max: Optional[float] = Query( - None, - description="Maximum value for the gravimetric energy (specific energy) in Wh/kg.", + None, description="Maximum value for the gravimetric energy (specific energy) in Wh/kg.", ), energy_grav_min: Optional[float] = Query( - None, - description="Minimum value for the gravimetric energy (specific energy) in Wh/kg.", + None, description="Minimum value for the gravimetric energy (specific energy) in Wh/kg.", ), energy_vol_max: Optional[float] = Query( - None, - description="Maximum value for the volumetric energy (energy_density) in Wh/l.", + None, description="Maximum value for the volumetric energy (energy_density) in Wh/l.", ), energy_vol_min: Optional[float] = Query( - None, - description="Minimum value for the volumetric energy (energy_density) in Wh/l.", + None, description="Minimum value for the volumetric energy (energy_density) in Wh/l.", ), fracA_charge_max: Optional[float] = Query( - None, - description="Maximum value for the atomic fraction of the working ion in the charged state.", + None, description="Maximum value for the atomic fraction of the working ion in the charged state.", ), fracA_charge_min: Optional[float] = Query( - None, - description="Minimum value for the atomic fraction of the working ion in the charged state.", + None, description="Minimum value for the atomic fraction of the working ion in the charged state.", ), fracA_discharge_max: Optional[float] = Query( - None, - description="Maximum value for the atomic fraction of the working ion in the discharged state.", + None, description="Maximum value for the atomic fraction of the working ion in the discharged state.", ), fracA_discharge_min: Optional[float] = Query( - None, - description="Minimum value for the atomic fraction of the working ion in the discharged state.", + None, description="Minimum value for the atomic fraction of the working ion in the discharged state.", ), ) -> STORE_PARAMS: @@ -165,20 +149,16 @@ class InsertionVoltageStepQuery(QueryOperator): def query( self, stability_charge_max: Optional[float] = Query( - None, - description="The maximum value of the energy above hull of the charged material.", + None, description="The maximum value of the energy above hull of the charged material.", ), stability_charge_min: Optional[float] = Query( - None, - description="The minimum value of the energy above hull of the charged material.", + None, description="The minimum value of the energy above hull of the charged material.", ), stability_discharge_max: Optional[float] = Query( - None, - description="The maximum value of the energy above hull of the discharged material.", + None, description="The maximum value of the energy above hull of the discharged material.", ), stability_discharge_min: Optional[float] = Query( - None, - description="The minimum value of the energy above hull of the discharged material.", + None, description="The minimum value of the energy above hull of the discharged material.", ), ) -> STORE_PARAMS: @@ -216,9 +196,7 @@ class InsertionElectrodeQuery(QueryOperator): def query( self, - working_ion: Optional[Element] = Query( - None, title="Element of the working ion" - ), + working_ion: Optional[Element] = Query(None, title="Element of the working ion"), num_steps_max: Optional[float] = Query( None, description="The maximum value of the The number of distinct voltage steps from fully charge to \ @@ -230,12 +208,10 @@ def query( discharge based on the stable intermediate states.", ), max_voltage_step_max: Optional[float] = Query( - None, - description="The maximum value of the maximum absolute difference in adjacent voltage steps.", + None, description="The maximum value of the maximum absolute difference in adjacent voltage steps.", ), max_voltage_step_min: Optional[float] = Query( - None, - description="The minimum value of maximum absolute difference in adjacent voltage steps.", + None, description="The minimum value of maximum absolute difference in adjacent voltage steps.", ), ) -> STORE_PARAMS: diff --git a/src/mp_api/routes/electrodes/resources.py b/src/mp_api/routes/electrodes/resources.py index ed63fdee..560134d7 100644 --- a/src/mp_api/routes/electrodes/resources.py +++ b/src/mp_api/routes/electrodes/resources.py @@ -1,26 +1,20 @@ -from mp_api.core.resource import GetResource +from maggma.api.query_operator.dynamic import NumericQuery +from maggma.api.resource import ReadOnlyResource from emmet.core.electrode import InsertionElectrodeDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery -from mp_api.routes.electrodes.query_operators import ( - VoltageStepQuery, - InsertionVoltageStepQuery, - InsertionElectrodeQuery, - ElectrodeFormulaQuery, -) +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from mp_api.routes.electrodes.query_operators import ElectrodeFormulaQuery from mp_api.routes.materials.query_operators import ElementsQuery def insertion_electrodes_resource(insertion_electrodes_store): - resource = GetResource( + resource = ReadOnlyResource( insertion_electrodes_store, InsertionElectrodeDoc, query_operators=[ ElectrodeFormulaQuery(), ElementsQuery(), - VoltageStepQuery(), - InsertionVoltageStepQuery(), - InsertionElectrodeQuery(), + NumericQuery(model=InsertionElectrodeDoc), SortQuery(), PaginationQuery(), SparseFieldsQuery( diff --git a/src/mp_api/routes/electronic_structure/query_operators.py b/src/mp_api/routes/electronic_structure/query_operators.py index de0ef861..a6f3ee33 100644 --- a/src/mp_api/routes/electronic_structure/query_operators.py +++ b/src/mp_api/routes/electronic_structure/query_operators.py @@ -3,8 +3,10 @@ from pymatgen.analysis.magnetism.analyzer import Ordering from pymatgen.electronic_structure.core import Spin, OrbitalType from pymatgen.core.periodic_table import Element +from emmet.core.mpid import MPID -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from mp_api.routes.electronic_structure.models.core import BSPathType, DOSProjectionType from collections import defaultdict @@ -17,29 +19,19 @@ class ESSummaryDataQuery(QueryOperator): def query( self, - band_gap_max: Optional[float] = Query(None, description="Maximum value for the band gap energy in eV."), - band_gap_min: Optional[float] = Query(None, description="Minimum value for the band gap energy in eV."), - efermi_max: Optional[float] = Query(None, description="Maximum value for the fermi energy in eV."), - efermi_min: Optional[float] = Query(None, description="Minimum value for the fermi energy in eV."), - magnetic_ordering: Optional[Ordering] = Query(None, description="Magnetic ordering associated with the data."), - is_gap_direct: Optional[bool] = Query(None, description="Whether a band gap is direct or not."), - is_metal: Optional[bool] = Query(None, description="Whether the material is considered a metal."), + magnetic_ordering: Optional[Ordering] = Query( + None, description="Magnetic ordering associated with the data." + ), + is_gap_direct: Optional[bool] = Query( + None, description="Whether a band gap is direct or not." + ), + is_metal: Optional[bool] = Query( + None, description="Whether the material is considered a metal." + ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict - d = { - "band_gap": [band_gap_min, band_gap_max], - "efermi": [efermi_min, efermi_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - if magnetic_ordering: crit["magnetic_ordering"] = magnetic_ordering.value @@ -68,13 +60,27 @@ def query( path_type: Optional[BSPathType] = Query( None, description="k-path selection convention for the band structure.", ), - band_gap_max: Optional[float] = Query(None, description="Maximum value for the band gap energy in eV."), - band_gap_min: Optional[float] = Query(None, description="Minimum value for the band gap energy in eV."), - efermi_max: Optional[float] = Query(None, description="Maximum value for the fermi energy in eV."), - efermi_min: Optional[float] = Query(None, description="Minimum value for the fermi energy in eV."), - magnetic_ordering: Optional[Ordering] = Query(None, description="Magnetic ordering associated with the data."), - is_gap_direct: Optional[bool] = Query(None, description="Whether a band gap is direct or not."), - is_metal: Optional[bool] = Query(None, description="Whether the material is considered a metal."), + band_gap_max: Optional[float] = Query( + None, description="Maximum value for the band gap energy in eV." + ), + band_gap_min: Optional[float] = Query( + None, description="Minimum value for the band gap energy in eV." + ), + efermi_max: Optional[float] = Query( + None, description="Maximum value for the fermi energy in eV." + ), + efermi_min: Optional[float] = Query( + None, description="Minimum value for the fermi energy in eV." + ), + magnetic_ordering: Optional[Ordering] = Query( + None, description="Magnetic ordering associated with the data." + ), + is_gap_direct: Optional[bool] = Query( + None, description="Whether a band gap is direct or not." + ), + is_metal: Optional[bool] = Query( + None, description="Whether the material is considered a metal." + ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict @@ -82,7 +88,10 @@ def query( if path_type is not None: d = { - f"bandstructure.{path_type.value}.band_gap": [band_gap_min, band_gap_max], + f"bandstructure.{path_type.value}.band_gap": [ + band_gap_min, + band_gap_max, + ], f"bandstructure.{path_type.value}.efermi": [efermi_min, efermi_max], } @@ -94,7 +103,9 @@ def query( crit[entry]["$lte"] = d[entry][1] if magnetic_ordering: - crit[f"bandstructure.{path_type.value}.magnetic_ordering"] = magnetic_ordering.value + crit[ + f"bandstructure.{path_type.value}.magnetic_ordering" + ] = magnetic_ordering.value if is_gap_direct is not None: crit[f"bandstructure.{path_type.value}.is_gap_direct"] = is_gap_direct @@ -126,15 +137,30 @@ def query( None, description="Projection type for the density of states data.", ), spin: Optional[Union[Literal["1", "-1"], Spin]] = Query( - None, description="Spin channel for density of states data. '1' corresponds to spin up.", - ), - element: Optional[Element] = Query(None, description="Element type for projected density of states data.",), - orbital: Optional[OrbitalType] = Query(None, description="Orbital type for projected density of states data.",), - band_gap_max: Optional[float] = Query(None, description="Maximum value for the band gap energy in eV."), - band_gap_min: Optional[float] = Query(None, description="Minimum value for the band gap energy in eV."), - efermi_max: Optional[float] = Query(None, description="Maximum value for the fermi energy in eV."), - efermi_min: Optional[float] = Query(None, description="Minimum value for the fermi energy in eV."), - magnetic_ordering: Optional[Ordering] = Query(None, description="Magnetic ordering associated with the data."), + None, + description="Spin channel for density of states data. '1' corresponds to spin up.", + ), + element: Optional[Element] = Query( + None, description="Element type for projected density of states data.", + ), + orbital: Optional[OrbitalType] = Query( + None, description="Orbital type for projected density of states data.", + ), + band_gap_max: Optional[float] = Query( + None, description="Maximum value for the band gap energy in eV." + ), + band_gap_min: Optional[float] = Query( + None, description="Minimum value for the band gap energy in eV." + ), + efermi_max: Optional[float] = Query( + None, description="Maximum value for the fermi energy in eV." + ), + efermi_min: Optional[float] = Query( + None, description="Minimum value for the fermi energy in eV." + ), + magnetic_ordering: Optional[Ordering] = Query( + None, description="Magnetic ordering associated with the data." + ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict @@ -145,7 +171,8 @@ def query( if projection_type is not None: if spin is None: raise HTTPException( - status_code=400, detail="Must specify a spin channel for querying dos summary data.", + status_code=400, + detail="Must specify a spin channel for querying dos summary data.", ) else: @@ -181,7 +208,9 @@ def query( key_prefix = f"element.{str(element.value)}.{str(orbital.name)}.{str(spin.value)}" else: - key_prefix = f"element.{str(element.value)}.total.{str(spin.value)}" + key_prefix = ( + f"element.{str(element.value)}.total.{str(spin.value)}" + ) key = f"dos.{key_prefix}.{entry}" @@ -203,3 +232,20 @@ def ensure_indexes(self): keys.append(f"dos.{proj_type.value}.$**") return [(key, False) for key in keys] + + +class ObjectQuery(QueryOperator): + """ + Method to generate a query on electronic structure object data. + """ + + def query( + self, + task_id: MPID = Query( + ..., + description=f"The calculation (task) ID associated with the data object", + ), + ) -> STORE_PARAMS: + + return {"criteria": {"task_id": str(task_id)}} + diff --git a/src/mp_api/routes/electronic_structure/resources.py b/src/mp_api/routes/electronic_structure/resources.py index 449f2e30..bac85c04 100644 --- a/src/mp_api/routes/electronic_structure/resources.py +++ b/src/mp_api/routes/electronic_structure/resources.py @@ -1,36 +1,31 @@ -from mp_api.core.resource import GetResource +from maggma.api.query_operator.dynamic import NumericQuery +from maggma.api.resource import ReadOnlyResource from emmet.core.electronic_structure import ElectronicStructureDoc - -from fastapi import HTTPException -from fastapi.param_functions import Path, Query - -from mp_api.core.utils import api_sanitize -from mp_api.core.models import Response -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.materials.query_operators import ( ElementsQuery, FormulaQuery, - MinMaxQuery, ) from mp_api.routes.electronic_structure.query_operators import ( ESSummaryDataQuery, BSDataQuery, DOSDataQuery, + ObjectQuery, ) from mp_api.routes.electronic_structure.models.doc import BSObjectDoc, DOSObjectDoc def es_resource(es_store): - resource = GetResource( + resource = ReadOnlyResource( es_store, ElectronicStructureDoc, query_operators=[ ESSummaryDataQuery(), FormulaQuery(), ElementsQuery(), - MinMaxQuery(), + NumericQuery(model=ElectronicStructureDoc), SortQuery(), PaginationQuery(), SparseFieldsQuery( @@ -43,59 +38,8 @@ def es_resource(es_store): return resource -def bs_resource(es_store, s3_store): - def custom_bs_endpoint_prep(self): - - self.s3 = s3_store - model = api_sanitize(BSObjectDoc, allow_dict_msonable=True) - model_name = model.__name__ - key_name = "task_id" - - async def get_object( - task_id: str = Query( - ..., alias=key_name, title=f"The {key_name} of the {model_name} to get", - ), - ): - f""" - Get's a document by the primary key in the store - - Args: - {key_name}: the calculation id of a single {model_name} - - Returns: - a single {model_name} document - """ - - self.s3.connect() - - bs_object_doc = None - - try: - bs_object_doc = self.s3.query_one({"task_id": task_id}) - - if not bs_object_doc: - raise HTTPException( - status_code=404, - detail=f"Band structure with task_id = {task_id} not found", - ) - - except ValueError: - raise HTTPException( - status_code=404, - detail=f"Band structure with task_id = {task_id} not found", - ) - - return {"data": [bs_object_doc]} - - self.router.get( - "/object/", - response_description=f"Get an {model_name} by {key_name}", - response_model=Response[model], - response_model_exclude_unset=True, - tags=self.tags, - )(get_object) - - resource = GetResource( +def bs_resource(es_store): + resource = ReadOnlyResource( es_store, ElectronicStructureDoc, query_operators=[ @@ -109,65 +53,30 @@ async def get_object( ], tags=["Electronic Structure"], enable_get_by_key=False, - custom_endpoint_funcs=[custom_bs_endpoint_prep], + sub_path="/bandstructure/", ) return resource -def dos_resource(es_store, s3_store): - def custom_dos_endpoint_prep(self): - - self.s3 = s3_store - model = api_sanitize(DOSObjectDoc, allow_dict_msonable=True) - model_name = model.__name__ - key_name = "task_id" - - async def get_object( - task_id: str = Query( - ..., alias=key_name, title=f"The {key_name} of the {model_name} to get", - ), - ): - f""" - Get's a document by the primary key in the store - - Args: - {key_name}: the calculation id of a single {model_name} - - Returns: - a single {model_name} document - """ - - self.s3.connect() - - dos_object_doc = None - - try: - dos_object_doc = self.s3.query_one({"task_id": task_id}) - - if not dos_object_doc: - raise HTTPException( - status_code=404, - detail=f"Density of states with task_id = {task_id} not found", - ) - - except ValueError: - raise HTTPException( - status_code=404, - detail=f"Density of states with task_id = {task_id} not found", - ) - - return {"data": [dos_object_doc]} +def bs_obj_resource(s3_store): + resource = ReadOnlyResource( + s3_store, + BSObjectDoc, + query_operators=[ + ObjectQuery(), + SparseFieldsQuery(BSObjectDoc, default_fields=["task_id", "last_updated"]), + ], + tags=["Electronic Structure"], + enable_get_by_key=False, + enable_default_search=True, + sub_path="/bandstructure/object/", + ) + return resource - self.router.get( - "/object/", - response_description=f"Get an {model_name} by {key_name}", - response_model=Response[model], - response_model_exclude_unset=True, - tags=self.tags, - )(get_object) - resource = GetResource( +def dos_resource(es_store): + resource = ReadOnlyResource( es_store, ElectronicStructureDoc, query_operators=[ @@ -180,8 +89,24 @@ async def get_object( ), ], tags=["Electronic Structure"], - custom_endpoint_funcs=[custom_dos_endpoint_prep], enable_get_by_key=False, + sub_path="/dos/", ) return resource + + +def dos_obj_resource(s3_store): + resource = ReadOnlyResource( + s3_store, + DOSObjectDoc, + query_operators=[ + ObjectQuery(), + SparseFieldsQuery(DOSObjectDoc, default_fields=["task_id", "last_updated"]), + ], + tags=["Electronic Structure"], + enable_get_by_key=False, + enable_default_search=True, + sub_path="/dos/object/", + ) + return resource diff --git a/src/mp_api/routes/eos/query_operators.py b/src/mp_api/routes/eos/query_operators.py deleted file mode 100644 index 4cc52bb9..00000000 --- a/src/mp_api/routes/eos/query_operators.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Optional -from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator - -from collections import defaultdict - - -class EnergyVolumeQuery(QueryOperator): - """ - Method to generate a query for ranges of equations of state energies and volumes - """ - - def query( - self, - volume_max: Optional[float] = Query( - None, description="Maximum value for the cell volume in A³/atom.", - ), - volume_min: Optional[float] = Query( - None, description="Minimum value for the cell volume in A³/atom.", - ), - energy_max: Optional[float] = Query( - None, description="Maximum value for the energy in eV/atom.", - ), - energy_min: Optional[float] = Query( - None, description="Minimum value for the energy in eV/atom.", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "volume": [volume_min, volume_max], - "energy": [energy_min, energy_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = ["volume", "energy"] - return [(key, False) for key in keys] diff --git a/src/mp_api/routes/eos/resources.py b/src/mp_api/routes/eos/resources.py index 48cc5ca6..c1dfc937 100644 --- a/src/mp_api/routes/eos/resources.py +++ b/src/mp_api/routes/eos/resources.py @@ -1,16 +1,16 @@ -from mp_api.core.resource import GetResource +from maggma.api.query_operator.dynamic import NumericQuery +from maggma.api.resource import ReadOnlyResource from mp_api.routes.eos.models import EOSDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery -from mp_api.routes.eos.query_operators import EnergyVolumeQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery def eos_resource(eos_store): - resource = GetResource( + resource = ReadOnlyResource( eos_store, EOSDoc, query_operators=[ - EnergyVolumeQuery(), + NumericQuery(model=EOSDoc), SortQuery(), PaginationQuery(), SparseFieldsQuery(EOSDoc, default_fields=["task_id"]), diff --git a/src/mp_api/routes/fermi/resources.py b/src/mp_api/routes/fermi/resources.py index 995fd51d..330d615f 100644 --- a/src/mp_api/routes/fermi/resources.py +++ b/src/mp_api/routes/fermi/resources.py @@ -1,17 +1,14 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.fermi.models import FermiDoc -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SparseFieldsQuery def fermi_resource(fermi_store): - resource = GetResource( + resource = ReadOnlyResource( fermi_store, FermiDoc, - query_operators=[ - PaginationQuery(), - SparseFieldsQuery(FermiDoc, default_fields=["task_id", "last_updated"]), - ], + query_operators=[PaginationQuery(), SparseFieldsQuery(FermiDoc, default_fields=["task_id", "last_updated"]),], tags=["Electronic Structure"], ) diff --git a/src/mp_api/routes/grain_boundary/query_operators.py b/src/mp_api/routes/grain_boundary/query_operators.py index b5a4199a..ec17b70f 100644 --- a/src/mp_api/routes/grain_boundary/query_operators.py +++ b/src/mp_api/routes/grain_boundary/query_operators.py @@ -1,55 +1,13 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict from mp_api.routes.grain_boundary.models import GBTypeEnum -class GBEnergyQuery(QueryOperator): - """ - Method to generate a query for energy values associated with grain boundary data - """ - - def query( - self, - gb_energy_max: Optional[float] = Query( - None, description="Maximum value for the grain boundary energy in J/m^2.", - ), - gb_energy_min: Optional[float] = Query( - None, description="Minimum value for the grain boundary energy in J/m^2.", - ), - w_sep_energy_max: Optional[float] = Query( - None, - description="Maximum value for the work of separation energy in J/m^2.", - ), - w_sep_energy_min: Optional[float] = Query( - None, description="Minimum value for work of separation energy in J/m^2.", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "gb_energy": [gb_energy_min, gb_energy_max], - "w_sep": [w_sep_energy_min, w_sep_energy_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = ["gb_energy", "w_sep"] - return [(key, False) for key in keys] - - class GBStructureQuery(QueryOperator): """ Method to generate a query for structure related data associated with grain boundary entries @@ -57,12 +15,6 @@ class GBStructureQuery(QueryOperator): def query( self, - rotation_angle_max: Optional[float] = Query( - None, description="Maximum value for the rotation angle in degrees.", - ), - rotation_angle_min: Optional[float] = Query( - None, description="Minimum value for the rotation angle in degrees.", - ), sigma: Optional[float] = Query(None, description="Value of sigma.",), type: Optional[GBTypeEnum] = Query(None, description="Grain boundary type.",), chemsys: Optional[str] = Query( @@ -72,17 +24,6 @@ def query( crit = defaultdict(dict) # type: dict - d = { - "rotation_angle": [rotation_angle_min, rotation_angle_max], - } - - for entry in d: - if d[entry][0]: - crit[entry] = {"$gte": d[entry][0]} - - if d[entry][1]: - crit[entry] = {"$lte": d[entry][1]} - if sigma: crit["sigma"] = sigma diff --git a/src/mp_api/routes/grain_boundary/resources.py b/src/mp_api/routes/grain_boundary/resources.py index 7058d45f..a4a08f04 100644 --- a/src/mp_api/routes/grain_boundary/resources.py +++ b/src/mp_api/routes/grain_boundary/resources.py @@ -1,21 +1,33 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.grain_boundary.models import GrainBoundaryDoc -from mp_api.routes.grain_boundary.query_operators import GBEnergyQuery, GBStructureQuery, GBTaskIDQuery -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from mp_api.routes.grain_boundary.query_operators import ( + GBStructureQuery, + GBTaskIDQuery, +) +from maggma.api.query_operator import ( + PaginationQuery, + SortQuery, + SparseFieldsQuery, + NumericQuery, +) def gb_resource(gb_store): - resource = GetResource( + resource = ReadOnlyResource( gb_store, GrainBoundaryDoc, query_operators=[ GBTaskIDQuery(), - GBEnergyQuery(), + NumericQuery( + model=GrainBoundaryDoc, excluded_fields=["rotation_axis", "gb_plane"] + ), GBStructureQuery(), SortQuery(), PaginationQuery(), - SparseFieldsQuery(GrainBoundaryDoc, default_fields=["task_id", "last_updated"]), + SparseFieldsQuery( + GrainBoundaryDoc, default_fields=["task_id", "last_updated"] + ), ], tags=["Grain Boundaries"], enable_get_by_key=False, diff --git a/src/mp_api/routes/magnetism/query_operators.py b/src/mp_api/routes/magnetism/query_operators.py index 4de17eaa..0299c646 100644 --- a/src/mp_api/routes/magnetism/query_operators.py +++ b/src/mp_api/routes/magnetism/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict @@ -14,9 +15,7 @@ class MagneticQuery(QueryOperator): def query( self, - ordering: Optional[MagneticOrderingEnum] = Query( - None, description="Magnetic ordering of the material." - ), + ordering: Optional[MagneticOrderingEnum] = Query(None, description="Magnetic ordering of the material."), total_magnetization_max: Optional[float] = Query( None, description="Maximum value for the total magnetization.", ), @@ -24,20 +23,16 @@ def query( None, description="Minimum value for the total magnetization.", ), total_magnetization_normalized_vol_max: Optional[float] = Query( - None, - description="Maximum value for the total magnetization normalized with volume.", + None, description="Maximum value for the total magnetization normalized with volume.", ), total_magnetization_normalized_vol_min: Optional[float] = Query( - None, - description="Minimum value for the total magnetization normalized with volume.", + None, description="Minimum value for the total magnetization normalized with volume.", ), total_magnetization_normalized_formula_units_max: Optional[float] = Query( - None, - description="Maximum value for the total magnetization normalized with formula units.", + None, description="Maximum value for the total magnetization normalized with formula units.", ), total_magnetization_normalized_formula_units_min: Optional[float] = Query( - None, - description="Minimum value for the total magnetization normalized with formula units.", + None, description="Minimum value for the total magnetization normalized with formula units.", ), num_magnetic_sites_max: Optional[int] = Query( None, description="Maximum value for the total number of magnetic sites.", @@ -46,22 +41,17 @@ def query( None, description="Minimum value for the total number of magnetic sites.", ), num_unique_magnetic_sites_max: Optional[int] = Query( - None, - description="Maximum value for the total number of unique magnetic sites.", + None, description="Maximum value for the total number of unique magnetic sites.", ), num_unique_magnetic_sites_min: Optional[int] = Query( - None, - description="Minimum value for the total number of unique magnetic sites.", + None, description="Minimum value for the total number of unique magnetic sites.", ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict d = { - "magnetism.total_magnetization": [ - total_magnetization_min, - total_magnetization_max, - ], + "magnetism.total_magnetization": [total_magnetization_min, total_magnetization_max,], "magnetism.total_magnetization_normalized_vol": [ total_magnetization_normalized_vol_min, total_magnetization_normalized_vol_max, @@ -70,14 +60,8 @@ def query( total_magnetization_normalized_formula_units_min, total_magnetization_normalized_formula_units_max, ], - "magnetism.num_magnetic_sites": [ - num_magnetic_sites_min, - num_magnetic_sites_max, - ], - "magnetism.num_unique_magnetic_sites": [ - num_unique_magnetic_sites_min, - num_unique_magnetic_sites_max, - ], + "magnetism.num_magnetic_sites": [num_magnetic_sites_min, num_magnetic_sites_max,], + "magnetism.num_unique_magnetic_sites": [num_unique_magnetic_sites_min, num_unique_magnetic_sites_max,], } # type: dict for entry in d: diff --git a/src/mp_api/routes/magnetism/resources.py b/src/mp_api/routes/magnetism/resources.py index ca77e61b..cdcbe209 100644 --- a/src/mp_api/routes/magnetism/resources.py +++ b/src/mp_api/routes/magnetism/resources.py @@ -1,12 +1,12 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.magnetism.models import MagnetismDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.magnetism.query_operators import MagneticQuery def magnetism_resource(magnetism_store): - resource = GetResource( + resource = ReadOnlyResource( magnetism_store, MagnetismDoc, query_operators=[ diff --git a/src/mp_api/routes/materials/client.py b/src/mp_api/routes/materials/client.py index e8013f63..46794d90 100644 --- a/src/mp_api/routes/materials/client.py +++ b/src/mp_api/routes/materials/client.py @@ -1,8 +1,6 @@ -import warnings from typing import List, Optional, Tuple from pymatgen.core.structure import Structure -# from mp_api.routes.materials.models import Structure from emmet.core.material import MaterialsDoc from emmet.core.symmetry import CrystalSystem diff --git a/src/mp_api/routes/materials/query_operators.py b/src/mp_api/routes/materials/query_operators.py index b326b459..f9514989 100644 --- a/src/mp_api/routes/materials/query_operators.py +++ b/src/mp_api/routes/materials/query_operators.py @@ -1,10 +1,15 @@ +from itertools import permutations from typing import Optional -from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator -from mp_api.routes.materials.utils import formula_to_criteria + from emmet.core.symmetry import CrystalSystem +from fastapi import Body, HTTPException, Query +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS +from mp_api.routes.materials.utils import formula_to_criteria +from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher +from pymatgen.core.composition import Composition, CompositionError from pymatgen.core.periodic_table import Element -from collections import defaultdict +from pymatgen.core.structure import Structure class FormulaQuery(QueryOperator): @@ -88,67 +93,6 @@ def query( return {"criteria": crit} -class MinMaxQuery(QueryOperator): - """ - Method to generate a query for quantities with a definable min and max - """ - - def query( - self, - nsites_max: Optional[int] = Query( - None, description="Maximum value for the number of sites", - ), - nsites_min: Optional[int] = Query( - None, description="Minimum value for the number of sites", - ), - nelements_max: Optional[float] = Query( - None, description="Maximum value for the number of elements.", - ), - nelements_min: Optional[float] = Query( - None, description="Minimum value for the number of elements.", - ), - volume_max: Optional[float] = Query( - None, description="Maximum value for the cell volume", - ), - volume_min: Optional[float] = Query( - None, description="Minimum value for the cell volume", - ), - density_max: Optional[float] = Query( - None, description="Maximum value for the density", - ), - density_min: Optional[float] = Query( - None, description="Minimum value for the density", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - entries = { - "nsites": [nsites_min, nsites_max], - "nelements": [nelements_min, nelements_max], - "volume": [volume_min, volume_max], - "density": [density_min, density_max], - } # type: dict - - for entry in entries: - if entries[entry][0]: - crit[entry]["$gte"] = entries[entry][0] - - if entries[entry][1]: - crit[entry]["$lte"] = entries[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = self._keys_from_query() - indexes = [] - for key in keys: - if "_min" in key: - key = key.replace("_min", "") - indexes.append((key, False)) - return indexes - - class SymmetryQuery(QueryOperator): """ Method to generate a query on symmetry information @@ -226,3 +170,177 @@ def query( crit.update({"material_id": {"$in": material_ids.split(",")}}) return {"criteria": crit} + + +class FindStructureQuery(QueryOperator): + """ + Method to generate a find structure query + """ + + def query( + self, + structure: Structure = Body( + ..., description="Pymatgen structure object to query with", + ), + ltol: float = Query( + 0.2, description="Fractional length tolerance. Default is 0.2.", + ), + stol: float = Query( + 0.3, + description="Site tolerance. Defined as the fraction of the average free \ + length per atom := ( V / Nsites ) ** (1/3). Default is 0.3.", + ), + angle_tol: float = Query( + 5, description="Angle tolerance in degrees. Default is 5 degrees.", + ), + limit: int = Query( + 1, + description="Maximum number of matches to show. Defaults to 1, only showing the best match.", + ), + ) -> STORE_PARAMS: + + self.ltol = ltol + self.stol = stol + self.angle_tol = angle_tol + self.limit = limit + self.structure = structure + + crit = {} + + try: + s = Structure.from_dict(structure) + except Exception: + raise HTTPException( + status_code=404, + detail="Body cannot be converted to a pymatgen structure object.", + ) + + crit.update({"composition_reduced": dict(s.composition.to_reduced_dict)}) + + return {"criteria": crit} + + def post_process(self, docs): + + s1 = Structure.from_dict(self.structure) + + m = StructureMatcher( + ltol=self.ltol, + stol=self.stol, + angle_tol=self.angle_tol, + primitive_cell=True, + scale=True, + attempt_supercell=False, + comparator=ElementComparator(), + ) + + matches = [] + + for doc in docs: + + s2 = Structure.from_dict(doc["structure"]) + matched = m.fit(s1, s2) + + if matched: + rms = m.get_rms_dist(s1, s2) + + matches.append( + { + "material_id": doc["material_id"], + "normalized_rms_displacement": rms[0], + "max_distance_paired_sites": rms[1], + } + ) + + response = sorted( + matches[: self.limit], + key=lambda x: ( + x["normalized_rms_displacement"], + x["max_distance_paired_sites"], + ), + ) + + return response + + def ensure_indexes(self): + return [("composition_reduced", False)] + + +class FormulaAutoCompleteQuery(QueryOperator): + """ + Method to generate a formula autocomplete query + """ + + def query( + self, + formula: str = Query(..., description="Human readable chemical formula.",), + limit: int = Query( + 10, description="Maximum number of matches to show. Defaults to 10.", + ), + ) -> STORE_PARAMS: + + self.formula = formula + self.limit = limit + + try: + comp = Composition(formula) + except CompositionError: + raise HTTPException( + status_code=400, detail="Invalid formula provided.", + ) + + ind_str = [] + eles = [] + + if len(comp) == 1: + d = comp.get_integer_formula_and_factor() + + s = d[0] + str(int(d[1])) if d[1] != 1 else d[0] + + ind_str.append(s) + eles.append(d[0]) + else: + + comp_red = comp.reduced_composition.items() + + for (i, j) in comp_red: + + if j != 1: + ind_str.append(i.name + str(int(j))) + else: + ind_str.append(i.name) + + eles.append(i.name) + + final_terms = ["".join(entry) for entry in permutations(ind_str)] + + pipeline = [ + { + "$search": { + "index": "formula_autocomplete", + "text": {"path": "formula_pretty", "query": final_terms}, + } + }, + { + "$project": { + "_id": 0, + "formula_pretty": 1, + "elements": 1, + "length": {"$strLenCP": "$formula_pretty"}, + } + }, + { + "$match": { + "length": {"$gte": len(final_terms[0])}, + "elements": {"$all": eles}, + } + }, + {"$limit": limit}, + {"$sort": {"length": 1}}, + {"$project": {"elements": 0, "length": 0}}, + ] + + return {"pipeline": pipeline} + + def ensure_indexes(self): + return [("formula_pretty", False)] + diff --git a/src/mp_api/routes/materials/resources.py b/src/mp_api/routes/materials/resources.py index 0382011a..b4bd5a0f 100644 --- a/src/mp_api/routes/materials/resources.py +++ b/src/mp_api/routes/materials/resources.py @@ -1,263 +1,70 @@ -from fastapi import HTTPException -from mp_api.core.resource import GetResource +from maggma.api.resource.read_resource import ReadOnlyResource +from maggma.api.resource.post_resource import PostOnlyResource +from maggma.api.resource.aggregation import AggregationResource -from emmet.core.material import MaterialsDoc +from emmet.core.material import MaterialsDoc +from mp_api.routes.materials.models.doc import FindStructure, FormulaAutocomplete -from mp_api.core.query_operator import ( +from maggma.api.query_operator import ( PaginationQuery, SparseFieldsQuery, - VersionQuery, SortQuery, + NumericQuery, ) + +from mp_api.core.settings import MAPISettings + from mp_api.routes.materials.query_operators import ( ElementsQuery, FormulaQuery, DeprecationQuery, - MinMaxQuery, SymmetryQuery, MultiTaskIDQuery, + FindStructureQuery, + FormulaAutoCompleteQuery, ) -from pymatgen.analysis.structure_matcher import StructureMatcher, ElementComparator -from pymatgen.core import Structure -from pymatgen.core import Composition from pymongo import MongoClient # type: ignore -from itertools import permutations -from fastapi import Query, Body - - -def materials_resource(materials_store, formula_autocomplete_store): - def custom_version_prep(self): - model_name = self.model.__name__ - - async def get_versions(): - f""" - Obtains the database versions for the data in {model_name} - - Returns: - A list of database versions one can use to query on - """ - - try: - conn = MongoClient(self.store.host, self.store.port) - db = conn[self.store.database] - if self.core.username != "": - db.authenticate(self.username, self.password) - - except AttributeError: - conn = MongoClient(self.store.uri) - db = conn[self.store.database] - - col_names = db.list_collection_names() - - d = [ - name.replace("_", ".")[15:] - for name in col_names - if "materials" in name - if name != "materials.core" - ] - - response = {"data": d} - - return response - - self.router.get( - "/versions/", - response_model_exclude_unset=True, - response_description=f"Get versions of {model_name}", - tags=self.tags, - )(get_versions) - - def custom_findstructure_prep(self): - model_name = self.model.__name__ - - async def find_structure( - structure: Structure = Body( - ..., title="Pymatgen structure object to query with", - ), - ltol: float = Query( - 0.2, title="Fractional length tolerance. Default is 0.2.", - ), - stol: float = Query( - 0.3, - title="Site tolerance. Defined as the fraction of the average free \ - length per atom := ( V / Nsites ) ** (1/3). Default is 0.3.", - ), - angle_tol: float = Query( - 5, title="Angle tolerance in degrees. Default is 5 degrees.", - ), - limit: int = Query( - 1, - title="Maximum number of matches to show. Defaults to 1, only showing the best match.", - ), - ): - """ - Obtains material structures that match a given input structure within some tolerance. - - Returns: - A list of Material IDs for materials with matched structures alongside the associated RMS values - """ - - try: - s = Structure.from_dict(structure.dict()) - except Exception: - raise HTTPException( - status_code=404, - detail="Body cannot be converted to a pymatgen structure object.", - ) - - m = StructureMatcher( - ltol=ltol, - stol=stol, - angle_tol=angle_tol, - primitive_cell=True, - scale=True, - attempt_supercell=False, - comparator=ElementComparator(), - ) - - crit = {"composition_reduced": dict(s.composition.to_reduced_dict)} - - self.store.connect() - - matches = [] - - for r in self.store.query( - criteria=crit, properties=["structure", "task_id"] - ): - s2 = Structure.from_dict(r["structure"]) - matched = m.fit(s, s2) - if matched: - rms = m.get_rms_dist(s, s2) - - matches.append( - { - "task_id": r["task_id"], - "normalized_rms_displacement": rms[0], - "max_distance_paired_sites": rms[1], - } - ) - - response = { - "data": sorted( - matches[:limit], - key=lambda x: ( - x["normalized_rms_displacement"], - x["max_distance_paired_sites"], - ), - ) - } - - return response - - self.router.post( - "/find_structure/", - response_model_exclude_unset=True, - response_description=f"Get matching structures using data from {model_name}", - tags=self.tags, - )(find_structure) - - def custom_autocomplete_prep(self): - async def formula_autocomplete( - text: str = Query( - ..., description="Text to run against formula autocomplete", - ), - limit: int = Query( - 10, description="Maximum number of matches to show. Defaults to 10", - ), - ): - store = formula_autocomplete_store - - try: - - comp = Composition(text) - - ind_str = [] - eles = [] - - if len(comp) == 1: - d = comp.get_integer_formula_and_factor() - - s = d[0] + str(int(d[1])) if d[1] != 1 else d[0] - - ind_str.append(s) - eles.append(d[0]) - else: - - comp_red = comp.reduced_composition.items() - - for (i, j) in comp_red: - - if j != 1: - ind_str.append(i.name + str(int(j))) - else: - ind_str.append(i.name) - - eles.append(i.name) - - final_terms = ["".join(entry) for entry in permutations(ind_str)] - - pipeline = [ - { - "$search": { - "index": "formula_autocomplete", - "text": {"path": "formula_pretty", "query": final_terms}, - } - }, - { - "$project": { - "_id": 0, - "formula_pretty": 1, - "elements": 1, - "length": {"$strLenCP": "$formula_pretty"}, - } - }, - { - "$match": { - "length": {"$gte": len(final_terms[0])}, - "elements": {"$all": eles}, - } - }, - {"$limit": limit}, - {"$sort": {"length": 1}}, - {"$project": {"elements": 0, "length": 0}}, - ] +def find_structure_resource(materials_store): + resource = PostOnlyResource( + materials_store, + FindStructure, + key_fields=["structure", "task_id"], + query_operators=[FindStructureQuery()], + tags=["Materials"], + sub_path="/find_structure/", + ) - store.connect() + return resource - data = list(store._collection.aggregate(pipeline, allowDiskUse=True)) - response = {"data": data} +def formula_autocomplete_resource(formula_autocomplete_store): + resource = AggregationResource( + formula_autocomplete_store, + FormulaAutocomplete, + pipeline_query_operator=FormulaAutoCompleteQuery(), + tags=["Materials"], + sub_path="/formula_autocomplete/", + ) - except Exception: - raise HTTPException( - status_code=404, - detail="Cannot autocomplete with provided formula.", - ) + return resource - return response - self.router.get( - "/formula_autocomplete/", - response_model_exclude_unset=True, - response_description="Get autocomplete results for a formula", - tags=self.tags, - )(formula_autocomplete) +def materials_resource(materials_store): - resource = GetResource( + resource = ReadOnlyResource( materials_store, MaterialsDoc, query_operators=[ - VersionQuery(), FormulaQuery(), ElementsQuery(), MultiTaskIDQuery(), SymmetryQuery(), DeprecationQuery(), - MinMaxQuery(), + NumericQuery(model=MaterialsDoc), SortQuery(), PaginationQuery(), SparseFieldsQuery( @@ -266,11 +73,7 @@ async def formula_autocomplete( ), ], tags=["Materials"], - custom_endpoint_funcs=[ - custom_version_prep, - custom_findstructure_prep, - custom_autocomplete_prep, - ], ) return resource + diff --git a/src/mp_api/routes/molecules/query_operators.py b/src/mp_api/routes/molecules/query_operators.py index 80496c8e..b51703ac 100644 --- a/src/mp_api/routes/molecules/query_operators.py +++ b/src/mp_api/routes/molecules/query_operators.py @@ -2,7 +2,8 @@ from fastapi import Query from pymatgen.core.periodic_table import Element from pymatgen.core import Composition -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict @@ -15,8 +16,7 @@ class MoleculeElementsQuery(QueryOperator): def query( self, elements: Optional[str] = Query( - None, - description="Query by elements in the material composition as a comma-separated list", + None, description="Query by elements in the material composition as a comma-separated list", ), ) -> STORE_PARAMS: @@ -39,33 +39,15 @@ class MoleculeBaseQuery(QueryOperator): def query( self, - nelements_max: Optional[float] = Query( - None, description="Maximum value for the number of elements.", - ), - nelements_min: Optional[float] = Query( - None, description="Minimum value for the number of elements.", - ), - EA_max: Optional[float] = Query( - None, description="Maximum value for the electron affinity in eV.", - ), - EA_min: Optional[float] = Query( - None, description="Minimum value for the electron affinity in eV.", - ), - IE_max: Optional[float] = Query( - None, description="Maximum value for the ionization energy in eV.", - ), - IE_min: Optional[float] = Query( - None, description="Minimum value for the ionization energy in eV.", - ), - charge_max: Optional[int] = Query( - None, description="Maximum value for the charge in +e.", - ), - charge_min: Optional[int] = Query( - None, description="Minimum value for the charge in +e.", - ), - pointgroup: Optional[str] = Query( - None, description="Point of the molecule in Schoenflies notation.", - ), + nelements_max: Optional[float] = Query(None, description="Maximum value for the number of elements.",), + nelements_min: Optional[float] = Query(None, description="Minimum value for the number of elements.",), + EA_max: Optional[float] = Query(None, description="Maximum value for the electron affinity in eV.",), + EA_min: Optional[float] = Query(None, description="Minimum value for the electron affinity in eV.",), + IE_max: Optional[float] = Query(None, description="Maximum value for the ionization energy in eV.",), + IE_min: Optional[float] = Query(None, description="Minimum value for the ionization energy in eV.",), + charge_max: Optional[int] = Query(None, description="Maximum value for the charge in +e.",), + charge_min: Optional[int] = Query(None, description="Minimum value for the charge in +e.",), + pointgroup: Optional[str] = Query(None, description="Point of the molecule in Schoenflies notation.",), smiles: Optional[str] = Query( None, description="The simplified molecular input line-entry system (SMILES) \ @@ -113,10 +95,7 @@ class MoleculeFormulaQuery(QueryOperator): """ def query( - self, - formula: Optional[str] = Query( - None, description="Chemical formula of the molecule.", - ), + self, formula: Optional[str] = Query(None, description="Chemical formula of the molecule.",), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict diff --git a/src/mp_api/routes/molecules/resources.py b/src/mp_api/routes/molecules/resources.py index 2da0849c..bcd02b86 100644 --- a/src/mp_api/routes/molecules/resources.py +++ b/src/mp_api/routes/molecules/resources.py @@ -1,7 +1,7 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.molecules.models import MoleculesDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.molecules.query_operators import ( MoleculeBaseQuery, MoleculeElementsQuery, @@ -11,7 +11,7 @@ def molecules_resource(molecules_store): - resource = GetResource( + resource = ReadOnlyResource( molecules_store, MoleculesDoc, query_operators=[ diff --git a/src/mp_api/routes/mpcomplete/models.py b/src/mp_api/routes/mpcomplete/models.py index c975f046..6ab25e88 100644 --- a/src/mp_api/routes/mpcomplete/models.py +++ b/src/mp_api/routes/mpcomplete/models.py @@ -1,5 +1,7 @@ -from pydantic import BaseModel, Field +from pydantic import Field +from pydantic.main import BaseModel from pymatgen.core.structure import Structure +from enum import Enum class MPCompleteDoc(BaseModel): @@ -7,10 +9,6 @@ class MPCompleteDoc(BaseModel): Defines data for MPComplete structure submissions """ - snl_id: str = Field( - None, title="SNL ID", description="ID for the submission.", - ) - structure: Structure = Field( None, title="Submitted structure", @@ -25,6 +23,14 @@ class MPCompleteDoc(BaseModel): None, title="Public email", description="Public email of submitter.", ) - comment: str = Field( - None, title="Submission comment", description="User comment for submission.", - ) + +class MPCompleteDataStatus(Enum): + """ + Submission status for MPComplete data + """ + + submitted = "SUBMITTED" + pending = "PENDING" + running = "RUNNING" + error = "ERROR" + complete = "COMPLETE" diff --git a/src/mp_api/routes/mpcomplete/query_operator.py b/src/mp_api/routes/mpcomplete/query_operator.py index 2f232082..e121eefe 100644 --- a/src/mp_api/routes/mpcomplete/query_operator.py +++ b/src/mp_api/routes/mpcomplete/query_operator.py @@ -1,7 +1,7 @@ from fastapi import Query, Body from pymatgen.core.structure import Structure -from mp_api.core.utils import STORE_PARAMS -from mp_api.core.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS +from maggma.api.query_operator import QueryOperator from uuid import uuid4 @@ -13,21 +13,16 @@ def query( structure: Structure = Body(..., title="Structure submission"), public_name: str = Query(..., title="Public name"), public_email: str = Query(..., title="Public email"), - comment: str = Query(..., title="Submission comment"), ) -> STORE_PARAMS: self.structure = structure self.public_name = public_name self.public_email = public_email - self.comment = comment - self.snl_id = str(uuid4()) crit = { - "snl_id": self.snl_id, "structure": structure, "public_email": public_email, "public_name": public_name, - "comment": comment, } return {"criteria": crit} @@ -36,12 +31,33 @@ def post_process(self, written): d = [ { - "snl_id": self.snl_id, "structure": self.structure, "public_email": self.public_email, "public_name": self.public_name, - "comment": self.comment, } ] return d + + +class MPCompleteGetQuery(QueryOperator): + """Query operators for querying on MPComplete data""" + + def query( + self, + public_name: str = Query(None, title="Public name"), + public_email: str = Query(None, title="Public email"), + ) -> STORE_PARAMS: + + self.public_name = public_name + self.public_email = public_email + + crit = {} + + if public_name is not None: + crit.update({"public_name": public_name}) + + if public_email is not None: + crit.update({"public_name": public_email}) + + return {"criteria": crit} diff --git a/src/mp_api/routes/mpcomplete/resources.py b/src/mp_api/routes/mpcomplete/resources.py index 09e3fd99..4f8f233b 100644 --- a/src/mp_api/routes/mpcomplete/resources.py +++ b/src/mp_api/routes/mpcomplete/resources.py @@ -1,14 +1,22 @@ -from mp_api.core.resource import ConsumerPostResource -from mp_api.routes.mpcomplete.models import MPCompleteDoc -from mp_api.routes.mpcomplete.query_operator import MPCompletePostQuery +from maggma.api.resource import SubmissionResource +from maggma.api.query_operator import PaginationQuery +from mp_api.routes.mpcomplete.models import MPCompleteDoc, MPCompleteDataStatus +from mp_api.routes.mpcomplete.query_operator import ( + MPCompletePostQuery, + MPCompleteGetQuery, +) def mpcomplete_resource(mpcomplete_store): - resource = ConsumerPostResource( + resource = SubmissionResource( mpcomplete_store, MPCompleteDoc, - query_operators=[MPCompletePostQuery()], + post_query_operators=[MPCompletePostQuery()], + get_query_operators=[MPCompleteGetQuery(), PaginationQuery()], tags=["MPComplete"], + state_enum=MPCompleteDataStatus, + default_state=MPCompleteDataStatus.submitted.value, + calculate_submission_id=True, include_in_schema=True, ) diff --git a/src/mp_api/routes/phonon/models.py b/src/mp_api/routes/phonon/models.py index 3d8ffb49..d96b8c2a 100644 --- a/src/mp_api/routes/phonon/models.py +++ b/src/mp_api/routes/phonon/models.py @@ -93,7 +93,7 @@ class PhononImgDoc(BaseModel): def last_updated_dict_ok(cls, v): return MontyDecoder().process_decoded(v) - # Make sure that the datetime field is properly formatted + # Make sure that the plot field is properly formatted @validator("plot", pre=True) def plot_bytes_ok(cls, v): return str(v) diff --git a/src/mp_api/routes/phonon/query_operators.py b/src/mp_api/routes/phonon/query_operators.py new file mode 100644 index 00000000..fba693e6 --- /dev/null +++ b/src/mp_api/routes/phonon/query_operators.py @@ -0,0 +1,24 @@ +import io +from emmet.core.mpid import MPID + +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS + +from fastapi import Path +from fastapi.responses import StreamingResponse + + +class PhononImgQuery(QueryOperator): + """ + Method to generate a query on phonon image data. + """ + + def query( + self, + task_id: MPID = Path( + ..., + description="The calculation (task) ID associated with the data object", + ), + ) -> STORE_PARAMS: + + return {"criteria": {"task_id": str(task_id)}} diff --git a/src/mp_api/routes/phonon/resources.py b/src/mp_api/routes/phonon/resources.py index fcdce97b..de710916 100644 --- a/src/mp_api/routes/phonon/resources.py +++ b/src/mp_api/routes/phonon/resources.py @@ -1,16 +1,12 @@ -import io -from fastapi.exceptions import HTTPException -from fastapi.param_functions import Path -from fastapi.responses import StreamingResponse - -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.phonon.models import PhononBSDoc, PhononImgDoc +from mp_api.routes.phonon.query_operators import PhononImgQuery -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SparseFieldsQuery def phonon_bs_resource(phonon_bs_store): - resource = GetResource( + resource = ReadOnlyResource( phonon_bs_store, PhononBSDoc, query_operators=[ @@ -25,60 +21,15 @@ def phonon_bs_resource(phonon_bs_store): def phonon_img_resource(phonon_img_store): - def phonon_img_prep(self): - async def get_image( - task_id: str = Path( - ..., alias="task_id", title="Materials Project ID of the material.", - ), - ): - """ - Obtains a phonon band structure image if available. - - Returns: - Phonon band structure image. - """ - - crit = {"task_id": task_id} - - self.store.connect() - - img = self.store.query_one(criteria=crit, properties=["plot"]) - - if img is None: - raise HTTPException( - status_code=404, detail=f"No image found for {task_id}.", - ) - - else: - img = img["plot"] - - response = StreamingResponse( - io.BytesIO(img), - media_type="img/png", - headers={ - "Content-Disposition": 'inline; filename="{}_phonon_bs.png"'.format( - task_id - ) - }, - ) - - return response - - self.router.get( - "/{task_id}/", - response_model_exclude_unset=True, - response_description="Get phonon band structure image.", - tags=self.tags, - )(get_image) - resource = GetResource( + resource = ReadOnlyResource( phonon_img_store, PhononImgDoc, - # query_operators=[PaginationQuery()], tags=["Phonon"], - custom_endpoint_funcs=[phonon_img_prep], enable_default_search=False, - enable_get_by_key=False, + enable_get_by_key=True, + key_fields=["plot", "task_id", "last_updated"], + sub_path="/image/", ) return resource diff --git a/src/mp_api/routes/piezo/query_operators.py b/src/mp_api/routes/piezo/query_operators.py index 49009500..c4806da8 100644 --- a/src/mp_api/routes/piezo/query_operators.py +++ b/src/mp_api/routes/piezo/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict diff --git a/src/mp_api/routes/piezo/resources.py b/src/mp_api/routes/piezo/resources.py index 330c32e4..b9e59bfd 100644 --- a/src/mp_api/routes/piezo/resources.py +++ b/src/mp_api/routes/piezo/resources.py @@ -1,12 +1,12 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.piezo.models import PiezoDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.piezo.query_operators import PiezoelectricQuery def piezo_resource(piezo_store): - resource = GetResource( + resource = ReadOnlyResource( piezo_store, PiezoDoc, query_operators=[ diff --git a/src/mp_api/routes/robocrys/query_operators.py b/src/mp_api/routes/robocrys/query_operators.py new file mode 100644 index 00000000..a3b61d44 --- /dev/null +++ b/src/mp_api/routes/robocrys/query_operators.py @@ -0,0 +1,57 @@ +from fastapi import HTTPException, Query +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS + + +class RoboTextSearchQuery(QueryOperator): + """ + Method to generate a robocrystallographer text search query + """ + + def query( + self, + keywords: str = Query( + ..., + description="Comma delimited string keywords to search robocrystallographer description text with", + ), + skip: int = Query(0, description="Number of entries to skip in the search"), + limit: int = Query( + 100, + description="Max number of entries to return in a single query. Limited to 100", + ), + ) -> STORE_PARAMS: + + if not keywords.strip(): + raise HTTPException( + status_code=400, detail="Must provide search keywords.", + ) + + pipeline = [ + { + "$search": { + "index": "description", + "regex": { + "query": [word.strip() for word in keywords.split(",") if word], + "path": "description", + "allowAnalyzedField": True, + }, + } + }, + { + "$project": { + "_id": 0, + "task_id": 1, + "description": 1, + "condensed_structure": 1, + "last_updates": 1, + "search_score": {"$meta": "searchScore"}, + } + }, + {"$sort": {"search_score": -1}}, + {"$skip": skip}, + {"$limit": limit}, + ] + return {"pipeline": pipeline} + + def ensure_indexes(self): + return [("description", False)] diff --git a/src/mp_api/routes/robocrys/resources.py b/src/mp_api/routes/robocrys/resources.py index 7526da55..c6a25c9a 100644 --- a/src/mp_api/routes/robocrys/resources.py +++ b/src/mp_api/routes/robocrys/resources.py @@ -1,70 +1,27 @@ -from fastapi import Query -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource +from maggma.api.resource.aggregation import AggregationResource from mp_api.routes.robocrys.models import RobocrysDoc +from mp_api.routes.robocrys.query_operators import RoboTextSearchQuery def robo_resource(robo_store): - def custom_robo_prep(self): - async def query_robo_text( - keywords: str = Query( - ..., - description="Comma delimited string keywords to search robocrystallographer description text with", - ), - skip: int = Query(0, description="Number of entries to skip in the search"), - limit: int = Query( - 100, - description="Max number of entries to return in a single query. Limited to 100", - ), - ): - - pipeline = [ - { - "$search": { - "index": "description", - "regex": { - "query": [word + ".*" for word in keywords.split(",")], - "path": "description", - "allowAnalyzedField": True, - }, - } - }, - { - "$project": { - "_id": 0, - "task_id": 1, - "description": 1, - "condensed_structure": 1, - "last_updates": 1, - "search_score": {"$meta": "searchScore"}, - } - }, - {"$sort": {"search_score": -1}}, - {"$skip": skip}, - {"$limit": limit}, - ] - - self.store.connect() - - data = list(self.store._collection.aggregate(pipeline, allowDiskUse=True)) - - response = {"data": data} + resource = ReadOnlyResource( + robo_store, + RobocrysDoc, + tags=["Robocrystallographer"], + enable_default_search=False, + ) - return response + return resource - self.router.get( - "/text_search/", - response_model=self.response_model, - response_model_exclude_unset=True, - response_description="Find robocrystallographer documents through text search.", - tags=self.tags, - )(query_robo_text) - resource = GetResource( +def robo_search_resource(robo_store): + resource = AggregationResource( robo_store, RobocrysDoc, + pipeline_query_operator=RoboTextSearchQuery(), tags=["Robocrystallographer"], - custom_endpoint_funcs=[custom_robo_prep], - enable_default_search=False, + sub_path="/text_search/", ) return resource diff --git a/src/mp_api/routes/search/query_operators.py b/src/mp_api/routes/search/query_operators.py index 0f9c0c7b..10ce8355 100644 --- a/src/mp_api/routes/search/query_operators.py +++ b/src/mp_api/routes/search/query_operators.py @@ -1,9 +1,14 @@ from enum import Enum -from typing import Optional +from typing import Optional, Literal from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from mp_api.routes.magnetism.models import MagneticOrderingEnum +from mp_api.routes.search.models import SearchStats + +from scipy.stats import gaussian_kde +import numpy as np from collections import defaultdict @@ -86,114 +91,6 @@ def ensure_indexes(self): return [("is_stable", False)] -class SearchElasticityQuery(QueryOperator): - """ - Method to generate a query for ranges of elasticity data in search docs - """ - - def query( - self, - k_voigt_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt average of the bulk modulus in GPa.", - ), - k_voigt_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt average of the bulk modulus in GPa.", - ), - k_reuss_max: Optional[float] = Query( - None, - description="Maximum value for the Reuss average of the bulk modulus in GPa.", - ), - k_reuss_min: Optional[float] = Query( - None, - description="Minimum value for the Reuss average of the bulk modulus in GPa.", - ), - k_vrh_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", - ), - k_vrh_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt-Reuss-Hill average of the bulk modulus in GPa.", - ), - g_voigt_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt average of the shear modulus in GPa.", - ), - g_voigt_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt average of the shear modulus in GPa.", - ), - g_reuss_max: Optional[float] = Query( - None, - description="Maximum value for the Reuss average of the shear modulus in GPa.", - ), - g_reuss_min: Optional[float] = Query( - None, - description="Minimum value for the Reuss average of the shear modulus in GPa.", - ), - g_vrh_max: Optional[float] = Query( - None, - description="Maximum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", - ), - g_vrh_min: Optional[float] = Query( - None, - description="Minimum value for the Voigt-Reuss-Hill average of the shear modulus in GPa.", - ), - elastic_anisotropy_max: Optional[float] = Query( - None, description="Maximum value for the elastic anisotropy.", - ), - elastic_anisotropy_min: Optional[float] = Query( - None, description="Maximum value for the elastic anisotropy.", - ), - poisson_max: Optional[float] = Query( - None, description="Maximum value for Poisson's ratio.", - ), - poisson_min: Optional[float] = Query( - None, description="Minimum value for Poisson's ratio.", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "k_voigt": [k_voigt_min, k_voigt_max], - "k_reuss": [k_reuss_min, k_reuss_max], - "k_vrh": [k_vrh_min, k_vrh_max], - "g_voigt": [g_voigt_min, g_voigt_max], - "g_reuss": [g_reuss_min, g_reuss_max], - "g_vrh": [g_vrh_min, g_vrh_max], - "universal_anisotropy": [elastic_anisotropy_min, elastic_anisotropy_max], - "homogeneous_poisson": [poisson_min, poisson_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = [ - key - for key in self._keys_from_query() - if "anisotropy" not in key and "poisson" not in key - ] - - indexes = [] - for key in keys: - if "_min" in key: - key = key.replace("_min", "") - indexes.append((key, False)) - indexes.append(("universal_anisotropy", False)) - indexes.append(("homogeneous_poisson", False)) - return indexes - - class SearchMagneticQuery(QueryOperator): """ Method to generate a query for magnetic data in search docs. @@ -204,126 +101,17 @@ def query( ordering: Optional[MagneticOrderingEnum] = Query( None, description="Magnetic ordering of the material." ), - total_magnetization_max: Optional[float] = Query( - None, description="Maximum value for the total magnetization.", - ), - total_magnetization_min: Optional[float] = Query( - None, description="Minimum value for the total magnetization.", - ), - total_magnetization_normalized_vol_max: Optional[float] = Query( - None, - description="Maximum value for the total magnetization normalized with volume.", - ), - total_magnetization_normalized_vol_min: Optional[float] = Query( - None, - description="Minimum value for the total magnetization normalized with volume.", - ), - total_magnetization_normalized_formula_units_max: Optional[float] = Query( - None, - description="Maximum value for the total magnetization normalized with formula units.", - ), - total_magnetization_normalized_formula_units_min: Optional[float] = Query( - None, - description="Minimum value for the total magnetization normalized with formula units.", - ), ) -> STORE_PARAMS: crit = defaultdict(dict) # type: dict - d = { - "total_magnetization": [total_magnetization_min, total_magnetization_max], - "total_magnetization_normalized_vol": [ - total_magnetization_normalized_vol_min, - total_magnetization_normalized_vol_max, - ], - "total_magnetization_normalized_formula_units": [ - total_magnetization_normalized_formula_units_min, - total_magnetization_normalized_formula_units_max, - ], - } # type: dict - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - if ordering: crit["ordering"] = ordering.value return {"criteria": crit} def ensure_indexes(self): - keys = [ - "total_magnetization", - "total_magnetization_normalized_vol", - "total_magnetization_normalized_formula_units", - ] - return [(key, False) for key in keys] - - -class SearchDielectricPiezoQuery(QueryOperator): - """ - Method to generate a query for ranges of dielectric and piezo data in search docs - """ - - def query( - self, - e_total_max: Optional[float] = Query( - None, description="Maximum value for the total dielectric constant.", - ), - e_total_min: Optional[float] = Query( - None, description="Minimum value for the total dielectric constant.", - ), - e_ionic_max: Optional[float] = Query( - None, description="Maximum value for the ionic dielectric constant.", - ), - e_ionic_min: Optional[float] = Query( - None, description="Minimum value for the ionic dielectric constant.", - ), - e_static_max: Optional[float] = Query( - None, description="Maximum value for the static dielectric constant.", - ), - e_static_min: Optional[float] = Query( - None, description="Minimum value for the static dielectric constant.", - ), - n_max: Optional[float] = Query( - None, description="Maximum value for the refractive index.", - ), - n_min: Optional[float] = Query( - None, description="Minimum value for the refractive index.", - ), - piezo_modulus_max: Optional[float] = Query( - None, description="Maximum value for the piezoelectric modulus in C/m².", - ), - piezo_modulus_min: Optional[float] = Query( - None, description="Minimum value for the piezoelectric modulus in C/m².", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "e_total": [e_total_min, e_total_max], - "e_ionic": [e_ionic_min, e_ionic_max], - "e_static": [e_static_min, e_static_max], - "n": [n_min, n_max], - "e_ij_max": [piezo_modulus_min, piezo_modulus_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = ["e_total", "e_ionic", "e_static", "n", "e_ij_max"] - return [(key, False) for key in keys] + return [("ordering", False)] class SearchIsTheoreticalQuery(QueryOperator): @@ -349,6 +137,106 @@ def ensure_indexes(self): return [("theoretical", False)] +class SearchStatsQuery(QueryOperator): + """ + Method to generate a query on search stats data + """ + + def __init__(self, search_doc): + valid_numeric_fields = tuple( + sorted(k for k, v in search_doc.__fields__.items() if v.type_ == float) + ) + + def query( + field: Literal[valid_numeric_fields] = Query( # type: ignore + valid_numeric_fields[0], + title=f"SearchDoc field to query on, must be a numerical field, " + f"choose from: {', '.join(valid_numeric_fields)}", + ), + num_samples: Optional[int] = Query( + None, title="If specified, will only sample this number of documents.", + ), + min_val: Optional[float] = Query( + None, + title="If specified, will only consider documents with field values " + "greater than or equal to this minimum value.", + ), + max_val: Optional[float] = Query( + None, + title="If specified, will only consider documents with field values " + "less than or equal to this minimum value.", + ), + num_points: int = Query( + 100, title="The number of values in the returned distribution." + ), + ) -> STORE_PARAMS: + + self.num_points = num_points + self.min_val = min_val + self.max_val = max_val + + if min_val or max_val: + pipeline = [{"$match": {field: {}}}] # type: list + if min_val: + pipeline[0]["$match"][field]["$gte"] = min_val + if max_val: + pipeline[0]["$match"][field]["$lte"] = max_val + else: + pipeline = [] + + if num_samples: + pipeline.append({"$sample": {"size": num_samples}}) + + pipeline.append({"$project": {field: 1, "_id": 0}}) + + return {"pipeline": pipeline} + + self.query = query + + def query(self): + " Stub query function for abstract class " + pass + + def post_process(self, docs): + + if docs: + field = list(docs[0].keys())[0] + + num_points = self.num_points + min_val = self.min_val + max_val = self.max_val + num_samples = len(docs) + + values = [d[field] for d in docs] + if not min_val: + min_val = min(values) + if not max_val: + max_val = max(values) + + kernel = gaussian_kde(values) + + distribution = list( + kernel( + np.arange(min_val, max_val, step=(max_val - min_val) / num_points,) # type: ignore + ) + ) + + median = float(np.median(values)) + mean = float(np.mean(values)) + + response = SearchStats( + field=field, + num_samples=num_samples, + min=min_val, + max=max_val, + distribution=distribution, + median=median, + mean=mean, + ) + + return [response] + + # TODO: # XAS and GB sub doc query operators # Add weighted work function to data diff --git a/src/mp_api/routes/search/resources.py b/src/mp_api/routes/search/resources.py index 97bf7a0d..960609d7 100644 --- a/src/mp_api/routes/search/resources.py +++ b/src/mp_api/routes/search/resources.py @@ -10,7 +10,6 @@ DeprecationQuery, ElementsQuery, FormulaQuery, - MinMaxQuery, SymmetryQuery, ) from mp_api.routes.search.models import SearchStats @@ -30,125 +29,37 @@ def search_resource(search_store): - def generate_stats_prep(self): - model_name = self.model.__name__ - - # we can only generate statistics for fields that return numbers - valid_numeric_fields = tuple( - sorted(k for k, v in SearchDoc().__fields__.items() if v.type_ == float) - ) - - async def generate_stats( - field: Literal[valid_numeric_fields] = Query( - valid_numeric_fields[0], - title=f"SearchDoc field to query on, must be a numerical field, " - f"choose from: {', '.join(valid_numeric_fields)}", - ), - num_samples: Optional[int] = Query( - None, title="If specified, will only sample this number of documents.", - ), - min_val: Optional[float] = Query( - None, - title="If specified, will only consider documents with field values " - "greater than or equal to this minimum value.", - ), - max_val: Optional[float] = Query( - None, - title="If specified, will only consider documents with field values " - "less than or equal to this minimum value.", - ), - num_points: int = Query( - 100, title="The number of values in the returned distribution." - ), - ): - """ - Generate statistics for a given numerical field specified in SearchDoc. - - Returns: - A SearchStats object. - """ - - self.store.connect() - - if min_val or max_val: - pipeline = [{"$match": {field: {}}}] # type: list - if min_val: - pipeline[0]["$match"][field]["$gte"] = min_val - if max_val: - pipeline[0]["$match"][field]["$lte"] = max_val - else: - pipeline = [] - - if num_samples: - pipeline.append({"$sample": {"size": num_samples}}) - - pipeline.append({"$project": {field: 1}}) - - values = [ - d[field] - for d in self.store._collection.aggregate(pipeline, allowDiskUse=True) - ] - if not min_val: - min_val = min(values) - if not max_val: - max_val = max(values) - - kernel = gaussian_kde(values) - - distribution = list( - kernel( - np.arange(min_val, max_val, step=(max_val - min_val) / num_points,) # type: ignore - ) - ) - - median = float(np.median(values)) - mean = float(np.mean(values)) - - response = SearchStats( - field=field, - num_samples=num_samples, - min=min_val, - max=max_val, - distribution=distribution, - median=median, - mean=mean, - ) - - return response - - self.router.get( - "/generate_statistics/", - response_model=SearchStats, - response_model_exclude_unset=True, - response_description=f"Generate statistics for a given field as {model_name}", - tags=self.tags, - )(generate_stats) - - resource = GetResource( + resource = ReadOnlyResource( search_store, SearchDoc, query_operators=[ MaterialIDsSearchQuery(), FormulaQuery(), ElementsQuery(), - MinMaxQuery(), SymmetryQuery(), - ThermoEnergyQuery(), SearchIsStableQuery(), SearchIsTheoreticalQuery(), - ESSummaryDataQuery(), - SearchElasticityQuery(), - SearchDielectricPiezoQuery(), - SurfaceMinMaxQuery(), SearchMagneticQuery(), + NumericQuery(model=SearchDoc, excluded_fields=["composition"]), HasPropsQuery(), DeprecationQuery(), SortQuery(), PaginationQuery(), SparseFieldsQuery(SearchDoc, default_fields=["material_id"]), ], - custom_endpoint_funcs=[generate_stats_prep], tags=["Search"], ) return resource + + +def search_stats_resource(search_store): + resource = AggregationResource( + search_store, + SearchStats, + pipeline_query_operator=SearchStatsQuery(SearchDoc), + tags=["Search"], + sub_path="/stats/", + ) + + return resource diff --git a/src/mp_api/routes/similarity/resources.py b/src/mp_api/routes/similarity/resources.py index 868fb6a5..29d4074b 100644 --- a/src/mp_api/routes/similarity/resources.py +++ b/src/mp_api/routes/similarity/resources.py @@ -1,16 +1,13 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.similarity.models import SimilarityDoc -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SparseFieldsQuery def similarity_resource(similarity_store): - resource = GetResource( + resource = ReadOnlyResource( similarity_store, SimilarityDoc, - query_operators=[ - PaginationQuery(), - SparseFieldsQuery(SimilarityDoc, default_fields=["task_id"]), - ], + query_operators=[PaginationQuery(), SparseFieldsQuery(SimilarityDoc, default_fields=["task_id"]),], tags=["Similarity"], enable_default_search=False, ) diff --git a/src/mp_api/routes/substrates/query_operators.py b/src/mp_api/routes/substrates/query_operators.py index bf831a72..12801ae5 100644 --- a/src/mp_api/routes/substrates/query_operators.py +++ b/src/mp_api/routes/substrates/query_operators.py @@ -1,6 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from collections import defaultdict @@ -12,15 +13,6 @@ class SubstrateStructureQuery(QueryOperator): def query( self, - film_id: Optional[str] = Query( - None, description="Materials Project ID of the film material.", - ), - substrate_id: Optional[str] = Query( - None, description="Materials Project ID of the substrate material.", - ), - substrate_formula: Optional[str] = Query( - None, description="Reduced formula of the substrate material.", - ), film_orientation: Optional[str] = Query( None, description="Comma separated integers defining the film surface orientation.", @@ -33,15 +25,6 @@ def query( crit = defaultdict(dict) # type: dict - if film_id: - crit["film_id"] = film_id - - if substrate_id: - crit["sub_id"] = substrate_id - - if substrate_formula: - crit["sub_form"] = substrate_formula - if film_orientation: crit["film_orient"] = film_orientation.replace(",", " ") diff --git a/src/mp_api/routes/substrates/resources.py b/src/mp_api/routes/substrates/resources.py index 03200e25..83433288 100644 --- a/src/mp_api/routes/substrates/resources.py +++ b/src/mp_api/routes/substrates/resources.py @@ -1,17 +1,26 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.substrates.models import SubstratesDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery -from mp_api.routes.substrates.query_operators import SubstrateStructureQuery, EnergyAreaQuery +from maggma.api.query_operator import ( + PaginationQuery, + SortQuery, + SparseFieldsQuery, + NumericQuery, + StringQueryOperator, +) +from mp_api.routes.substrates.query_operators import SubstrateStructureQuery def substrates_resource(substrates_store): - resource = GetResource( + resource = ReadOnlyResource( substrates_store, SubstratesDoc, query_operators=[ SubstrateStructureQuery(), - EnergyAreaQuery(), + NumericQuery(model=SubstratesDoc), + StringQueryOperator( + model=SubstratesDoc, excluded_fields=["film_orient", "orient"] + ), SortQuery(), PaginationQuery(), SparseFieldsQuery(SubstratesDoc, default_fields=["film_id", "sub_id"]), diff --git a/src/mp_api/routes/surface_properties/query_operators.py b/src/mp_api/routes/surface_properties/query_operators.py index 10ca6b90..75113ca2 100644 --- a/src/mp_api/routes/surface_properties/query_operators.py +++ b/src/mp_api/routes/surface_properties/query_operators.py @@ -1,75 +1,7 @@ from typing import Optional from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator - -from collections import defaultdict - - -class SurfaceMinMaxQuery(QueryOperator): - """ - Method to generate a query for ranges of surface energy, anisotropy, and shape factor. - """ - - def query( - self, - weighted_surface_energy_max: Optional[float] = Query( - None, description="Maximum value for the weighted surface energy in J/m².", - ), - weighted_surface_energy_min: Optional[float] = Query( - None, description="Minimum value for the weighted surface energy in J/m².", - ), - weighted_work_function_max: Optional[float] = Query( - None, description="Maximum value for the weighted work function in eV.", - ), - weighted_work_function_min: Optional[float] = Query( - None, description="Minimum value for the weighted work function in eV.", - ), - surface_anisotropy_max: Optional[float] = Query( - None, description="Maximum value for the surface energy anisotropy.", - ), - surface_anisotropy_min: Optional[float] = Query( - None, description="Minimum value for the surface energy anisotropy.", - ), - shape_factor_max: Optional[float] = Query( - None, description="Maximum value for the shape factor.", - ), - shape_factor_min: Optional[float] = Query( - None, description="Minimum value for the shape factor.", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "weighted_surface_energy": [ - weighted_surface_energy_min, - weighted_surface_energy_max, - ], - "weighted_work_function": [ - weighted_work_function_min, - weighted_work_function_max, - ], - "surface_anisotropy": [surface_anisotropy_min, surface_anisotropy_max], - "shape_factor": [shape_factor_min, shape_factor_max], - } - - for entry in d: - if d[entry][0]: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1]: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = [ - "weighted_surface_energy", - "weighted_work_function", - "surface_anisotropy", - "shape_factor", - ] - return [(key, False) for key in keys] +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS class ReconstructedQuery(QueryOperator): diff --git a/src/mp_api/routes/surface_properties/resources.py b/src/mp_api/routes/surface_properties/resources.py index 9c013856..2f43c32a 100644 --- a/src/mp_api/routes/surface_properties/resources.py +++ b/src/mp_api/routes/surface_properties/resources.py @@ -1,19 +1,17 @@ -from mp_api.core.resource import GetResource +from maggma.api.query_operator.dynamic import NumericQuery +from maggma.api.resource import ReadOnlyResource from mp_api.routes.surface_properties.models import SurfacePropDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery -from mp_api.routes.surface_properties.query_operators import ( - SurfaceMinMaxQuery, - ReconstructedQuery, -) +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from mp_api.routes.surface_properties.query_operators import ReconstructedQuery def surface_props_resource(surface_prop_store): - resource = GetResource( + resource = ReadOnlyResource( surface_prop_store, SurfacePropDoc, query_operators=[ - SurfaceMinMaxQuery(), + NumericQuery(model=SurfacePropDoc), ReconstructedQuery(), SortQuery(), PaginationQuery(), diff --git a/src/mp_api/routes/tasks/query_operators.py b/src/mp_api/routes/tasks/query_operators.py index faac77a5..d5ec1f4d 100644 --- a/src/mp_api/routes/tasks/query_operators.py +++ b/src/mp_api/routes/tasks/query_operators.py @@ -1,4 +1,5 @@ -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from mp_api.routes.tasks.utils import calcs_reversed_to_trajectory from fastapi import Query from typing import Optional @@ -70,7 +71,7 @@ class DeprecationQuery(QueryOperator): def query( self, task_ids: str = Query( - None, description="Comma-separated list of task_ids to query on" + ..., description="Comma-separated list of task_ids to query on" ), ) -> STORE_PARAMS: diff --git a/src/mp_api/routes/tasks/resources.py b/src/mp_api/routes/tasks/resources.py index 75e416eb..3a31330c 100644 --- a/src/mp_api/routes/tasks/resources.py +++ b/src/mp_api/routes/tasks/resources.py @@ -1,7 +1,7 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.tasks.models import DeprecationDoc, TaskDoc, TrajectoryDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.tasks.query_operators import ( MultipleTaskIDsQuery, TrajectoryQuery, @@ -14,7 +14,7 @@ def task_resource(task_store): - resource = GetResource( + resource = ReadOnlyResource( task_store, TaskDoc, query_operators=[ @@ -34,25 +34,27 @@ def task_resource(task_store): def task_deprecation_resource(materials_store): - resource = GetResource( + resource = ReadOnlyResource( materials_store, DeprecationDoc, query_operators=[DeprecationQuery(), PaginationQuery()], tags=["Tasks"], enable_get_by_key=False, enable_default_search=True, + sub_path="/deprecation/", ) return resource def trajectory_resource(task_store): - resource = GetResource( + resource = ReadOnlyResource( task_store, TrajectoryDoc, query_operators=[TrajectoryQuery(), PaginationQuery()], key_fields=["task_id", "calcs_reversed"], tags=["Tasks"], + sub_path="/trajectory/", ) return resource diff --git a/src/mp_api/routes/thermo/query_operators.py b/src/mp_api/routes/thermo/query_operators.py index af04c900..0e9917d8 100644 --- a/src/mp_api/routes/thermo/query_operators.py +++ b/src/mp_api/routes/thermo/query_operators.py @@ -2,7 +2,8 @@ from collections import defaultdict from fastapi import Query from pymatgen.core import Element -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS class ThermoChemicalQuery(QueryOperator): @@ -68,84 +69,3 @@ def query( def ensure_indexes(self): keys = self._keys_from_query() return [(key, False) for key in keys] - - -class ThermoEnergyQuery(QueryOperator): - """ - Method to generate a query for ranges of thermo energy data - """ - - def query( - self, - energy_per_atom_max: Optional[float] = Query( - None, - description="Maximum value for the corrected total energy in eV/atom.", - ), - energy_per_atom_min: Optional[float] = Query( - None, - description="Minimum value for the corrected total energy in eV/atom.", - ), - formation_energy_per_atom_max: Optional[float] = Query( - None, description="Maximum value for the formation energy in eV/atom.", - ), - formation_energy_per_atom_min: Optional[float] = Query( - None, description="Minimum value for the formation energy in eV/atom.", - ), - energy_above_hull_max: Optional[float] = Query( - None, description="Maximum value for the energy above the hull in eV/atom.", - ), - energy_above_hull_min: Optional[float] = Query( - None, description="Minimum value for the energy above the hull in eV/atom.", - ), - equillibrium_reaction_energy_per_atom_max: Optional[float] = Query( - None, - description="Maximum value for the equilibrium reaction energy in eV/atom.", - ), - equillibrium_reaction_energy_per_atom_min: Optional[float] = Query( - None, - description="Minimum value for the equilibrium reaction energy in eV/atom.", - ), - uncorrected_energy_per_atom_max: Optional[float] = Query( - None, description="Maximum value for the uncorrected total energy in eV.", - ), - uncorrected_energy_per_atom_min: Optional[float] = Query( - None, description="Minimum value for the uncorrected total energy in eV.", - ), - ) -> STORE_PARAMS: - - crit = defaultdict(dict) # type: dict - - d = { - "energy_per_atom": [energy_per_atom_min, energy_per_atom_max], - "formation_energy_per_atom": [ - formation_energy_per_atom_min, - formation_energy_per_atom_max, - ], - "energy_above_hull": [energy_above_hull_min, energy_above_hull_max], - "equillibrium_reaction_energy_per_atom": [ - equillibrium_reaction_energy_per_atom_min, - equillibrium_reaction_energy_per_atom_max, - ], - "uncorrected_energy": [ - uncorrected_energy_per_atom_min, - uncorrected_energy_per_atom_max, - ], - } - - for entry in d: - if d[entry][0] is not None: - crit[entry]["$gte"] = d[entry][0] - - if d[entry][1] is not None: - crit[entry]["$lte"] = d[entry][1] - - return {"criteria": crit} - - def ensure_indexes(self): - keys = self._keys_from_query() - indexes = [] - for key in keys: - if "_min" in key: - key = key.replace("_min", "") - indexes.append((key, False)) - return indexes diff --git a/src/mp_api/routes/thermo/resources.py b/src/mp_api/routes/thermo/resources.py index e867da5c..5cb83bac 100644 --- a/src/mp_api/routes/thermo/resources.py +++ b/src/mp_api/routes/thermo/resources.py @@ -1,31 +1,31 @@ -from mp_api.core.resource import GetResource +from maggma.api.query_operator.dynamic import NumericQuery +from maggma.api.resource import ReadOnlyResource from emmet.core.thermo import ThermoDoc -from mp_api.core.query_operator import ( +from maggma.api.query_operator import ( PaginationQuery, SortQuery, SparseFieldsQuery, - VersionQuery, ) from mp_api.routes.thermo.query_operators import ( ThermoChemicalQuery, - ThermoEnergyQuery, IsStableQuery, ) +from mp_api.core.settings import MAPISettings + from mp_api.routes.materials.query_operators import MultiMaterialIDQuery def thermo_resource(thermo_store): - resource = GetResource( + resource = ReadOnlyResource( thermo_store, ThermoDoc, query_operators=[ - VersionQuery(), MultiMaterialIDQuery(), ThermoChemicalQuery(), IsStableQuery(), - ThermoEnergyQuery(), + NumericQuery(model=ThermoDoc), SortQuery(), PaginationQuery(), SparseFieldsQuery( diff --git a/src/mp_api/routes/wulff/resources.py b/src/mp_api/routes/wulff/resources.py index 30cc7999..eff4774d 100644 --- a/src/mp_api/routes/wulff/resources.py +++ b/src/mp_api/routes/wulff/resources.py @@ -1,17 +1,14 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from mp_api.routes.wulff.models import WulffDoc -from mp_api.core.query_operator import PaginationQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SparseFieldsQuery def wulff_resource(wulff_store): - resource = GetResource( + resource = ReadOnlyResource( wulff_store, WulffDoc, - query_operators=[ - PaginationQuery(), - SparseFieldsQuery(WulffDoc, default_fields=["task_id"]), - ], + query_operators=[PaginationQuery(), SparseFieldsQuery(WulffDoc, default_fields=["task_id"]),], tags=["Surface Properties"], enable_default_search=False, ) diff --git a/src/mp_api/routes/xas/query_operator.py b/src/mp_api/routes/xas/query_operator.py index 66ac0499..ebac3c1d 100644 --- a/src/mp_api/routes/xas/query_operator.py +++ b/src/mp_api/routes/xas/query_operator.py @@ -1,5 +1,6 @@ from fastapi import Query -from mp_api.core.query_operator import STORE_PARAMS, QueryOperator +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS from emmet.core.xas import Edge, Type from pymatgen.core.periodic_table import Element from typing import Optional @@ -35,10 +36,7 @@ class XASTaskIDQuery(QueryOperator): """ def query( - self, - task_ids: Optional[str] = Query( - None, description="Comma-separated list of task_ids to query on" - ), + self, task_ids: Optional[str] = Query(None, description="Comma-separated list of task_ids to query on"), ) -> STORE_PARAMS: crit = {} diff --git a/src/mp_api/routes/xas/resources.py b/src/mp_api/routes/xas/resources.py index 8665406c..15989756 100644 --- a/src/mp_api/routes/xas/resources.py +++ b/src/mp_api/routes/xas/resources.py @@ -1,13 +1,13 @@ -from mp_api.core.resource import GetResource +from maggma.api.resource import ReadOnlyResource from emmet.core.xas import XASDoc -from mp_api.core.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.materials.query_operators import ElementsQuery, FormulaQuery from mp_api.routes.xas.query_operator import XASQuery, XASTaskIDQuery def xas_resource(xas_store): - resource = GetResource( + resource = ReadOnlyResource( xas_store, XASDoc, query_operators=[