From 113557a7a2efa115f3b7644dac40df1274fcac10 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Wed, 10 Nov 2021 11:49:42 -0800 Subject: [PATCH] Dielectric, Piezoelectric, Magnetism, Bonds, and Phase Diagrams (#420) * Emmet bump * Update dielectric, piezoelectric, magnetism * task_id to material_id in clients * Fix collection names * Remove sub doc fields from client tests * Ignore some endpoints in tests temp * Temp xfail on endpoints * Bump emmet * Add bonds endpoint * Bump emmet * Phase diagram endpoint added to thermo * Temp xfail on thermo pd client method * Linting * Alter materials and thermo collection names * Fix tests * Linting * Fix robocrys test * Update robocrys test * Update database names * Fix missing fstrings * Remove stray print * Temp xfail on find structure test --- app.py | 73 +++++++++--- requirements.txt | 2 +- src/mp_api/routes/bonds/__init__.py | 0 src/mp_api/routes/bonds/client.py | 105 ++++++++++++++++++ src/mp_api/routes/bonds/query_operators.py | 102 +++++++++++++++++ src/mp_api/routes/bonds/resources.py | 26 +++++ src/mp_api/routes/dielectric/client.py | 4 +- src/mp_api/routes/dielectric/models.py | 31 ------ .../routes/dielectric/query_operators.py | 16 +-- src/mp_api/routes/dielectric/resources.py | 4 +- src/mp_api/routes/magnetism/client.py | 4 +- src/mp_api/routes/magnetism/models.py | 78 ------------- .../routes/magnetism/query_operators.py | 20 ++-- src/mp_api/routes/magnetism/resources.py | 6 +- src/mp_api/routes/piezo/client.py | 8 +- src/mp_api/routes/piezo/models.py | 62 ----------- src/mp_api/routes/piezo/query_operators.py | 4 +- src/mp_api/routes/piezo/resources.py | 8 +- src/mp_api/routes/thermo/client.py | 22 +++- src/mp_api/routes/thermo/resources.py | 15 ++- tests/bonds/__init__.py | 0 tests/bonds/test_client.py | 85 ++++++++++++++ tests/bonds/test_query_operators.py | 68 ++++++++++++ tests/dielectric/test_client.py | 3 +- tests/dielectric/test_query_operators.py | 8 +- tests/electrodes/test_client.py | 1 - tests/magnetism/test_client.py | 3 +- tests/magnetism/test_query_operators.py | 14 +-- tests/piezo/test_client.py | 3 +- tests/piezo/test_query_operators.py | 4 +- tests/robocrys/test_client.py | 5 +- tests/test_client.py | 8 +- tests/test_mprester.py | 12 +- tests/thermo/test_client.py | 9 ++ 34 files changed, 559 insertions(+), 254 deletions(-) create mode 100644 src/mp_api/routes/bonds/__init__.py create mode 100644 src/mp_api/routes/bonds/client.py create mode 100644 src/mp_api/routes/bonds/query_operators.py create mode 100644 src/mp_api/routes/bonds/resources.py delete mode 100644 src/mp_api/routes/dielectric/models.py delete mode 100644 src/mp_api/routes/magnetism/models.py delete mode 100644 src/mp_api/routes/piezo/models.py create mode 100644 tests/bonds/__init__.py create mode 100644 tests/bonds/test_client.py create mode 100644 tests/bonds/test_query_operators.py diff --git a/app.py b/app.py index e7933d5c..62b150eb 100644 --- a/app.py +++ b/app.py @@ -14,13 +14,18 @@ debug = default_settings.DEBUG materials_store_json = os.environ.get("MATERIALS_STORE", "materials_store.json") +bonds_store_json = os.environ.get("BONDS_STORE", "bonds_store.json") formula_autocomplete_store_json = os.environ.get( "FORMULA_AUTOCOMPLETE_STORE", "formula_autocomplete_store.json" ) task_store_json = os.environ.get("TASK_STORE", "task_store.json") thermo_store_json = os.environ.get("THERMO_STORE", "thermo_store.json") -dielectric_piezo_store_json = os.environ.get( - "DIELECTRIC_PIEZO_STORE", "dielectric_piezo_store.json" +phase_diagram_store_json = os.environ.get( + "PHASE_DIAGRAM_STORE", "phase_diagram_store.json" +) +dielectric_store_json = os.environ.get("DIELECTRIC_STORE", "dielectric_store.json") +piezoelectric_store_json = os.environ.get( + "PIEZOELECTRIC_STORE", "piezoelectric_store.json" ) magnetism_store_json = os.environ.get("MAGNETISM_STORE", "magnetism_store.json") phonon_bs_store_json = os.environ.get("PHONON_BS_STORE", "phonon_bs_store.json") @@ -75,7 +80,14 @@ uri=f"mongodb+srv://{db_uri}", database=f"mp_core_{db_suffix}", key="material_id", - collection_name=f"materials.core_{db_version}", + collection_name="materials", + ) + + bonds_store = MongoURIStore( + uri=f"mongodb+srv://{db_uri}", + database=f"mp_core_{db_suffix}", + key="material_id", + collection_name="bonds", ) formula_autocomplete_store = MongoURIStore( @@ -96,20 +108,34 @@ uri=f"mongodb+srv://{db_uri}", database=f"mp_core_{db_suffix}", key="material_id", - collection_name=f"thermo_{db_version}", + collection_name="thermo", ) - dielectric_piezo_store = MongoURIStore( + phase_diagram_store = MongoURIStore( uri=f"mongodb+srv://{db_uri}", database=f"mp_core_{db_suffix}", - key="task_id", + key="chemsys", + collection_name="phase_diagram", + ) + + dielectric_store = MongoURIStore( + uri=f"mongodb+srv://{db_uri}", + database=f"mp_core_{db_suffix}", + key="material_id", collection_name="dielectric", ) + piezoelectric_store = MongoURIStore( + uri=f"mongodb+srv://{db_uri}", + database=f"mp_core_{db_suffix}", + key="material_id", + collection_name="piezoelectric", + ) + magnetism_store = MongoURIStore( uri=f"mongodb+srv://{db_uri}", database=f"mp_core_{db_suffix}", - key="task_id", + key="material_id", collection_name="magnetism", ) @@ -308,10 +334,13 @@ else: materials_store = loadfn(materials_store_json) + bonds_store = loadfn(bonds_store_json) formula_autocomplete_store = loadfn(formula_autocomplete_store_json) task_store = loadfn(task_store_json) thermo_store = loadfn(thermo_store_json) - dielectric_piezo_store = loadfn(dielectric_piezo_store_json) + phase_diagram_store = loadfn(phase_diagram_store_json) + dielectric_store = loadfn(dielectric_store_json) + piezoelectric_store = loadfn(piezoelectric_store_json) magnetism_store = loadfn(magnetism_store_json) phonon_bs_store = loadfn(phonon_bs_store_json) eos_store = loadfn(eos_store_json) @@ -362,7 +391,10 @@ } ) -# resources.update({"find_structure": find_structure_resource(materials_store)}) +# Bonds +from mp_api.routes.bonds.resources import bonds_resource + +resources.update({"bonds": [bonds_resource(bonds_store)]}) # Tasks from mp_api.routes.tasks.resources import ( @@ -382,25 +414,32 @@ ) # Thermo -from mp_api.routes.thermo.resources import thermo_resource +from mp_api.routes.thermo.resources import phase_diagram_resource, thermo_resource -resources.update({"thermo": [thermo_resource(thermo_store)]}) +resources.update( + { + "thermo": [ + phase_diagram_resource(phase_diagram_store), + thermo_resource(thermo_store), + ] + } +) # Dielectric from mp_api.routes.dielectric.resources import dielectric_resource -resources.update({"dielectric": [dielectric_resource(dielectric_piezo_store)]}) +resources.update({"dielectric": [dielectric_resource(dielectric_store)]}) + +# Piezoelectric +from mp_api.routes.piezo.resources import piezo_resource + +resources.update({"piezoelectric": [piezo_resource(piezoelectric_store)]}) # Magnetism from mp_api.routes.magnetism.resources import magnetism_resource resources.update({"magnetism": [magnetism_resource(magnetism_store)]}) -# Piezoelectric -from mp_api.routes.piezo.resources import piezo_resource - -resources.update({"piezoelectric": [piezo_resource(dielectric_piezo_store)]}) - # Phonon from mp_api.routes.phonon.resources import phonon_bsdos_resource diff --git a/requirements.txt b/requirements.txt index 21bed449..eaa24ccf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,6 @@ typing-extensions==3.10.0.2 maggma==0.31.0 requests==2.26.0 monty==2021.8.17 -emmet-core==0.15.11 +emmet-core==0.15.7 ratelimit==2.2.1 mpcontribs-client>=3.14.3 diff --git a/src/mp_api/routes/bonds/__init__.py b/src/mp_api/routes/bonds/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/mp_api/routes/bonds/client.py b/src/mp_api/routes/bonds/client.py new file mode 100644 index 00000000..0e18b106 --- /dev/null +++ b/src/mp_api/routes/bonds/client.py @@ -0,0 +1,105 @@ +from typing import List, Optional, Tuple +from collections import defaultdict + +from mp_api.core.client import BaseRester +from emmet.core.bonds import BondingDoc + + +class BondsRester(BaseRester[BondingDoc]): + + suffix = "bonds" + document_model = BondingDoc # type: ignore + primary_key = "material_id" + + def search_bonds_docs( + self, + max_bond_length: Optional[Tuple[float, float]] = None, + min_bond_length: Optional[Tuple[float, float]] = None, + mean_bond_length: Optional[Tuple[float, float]] = None, + coordination_envs: Optional[List[str]] = None, + coordination_envs_anonymous: Optional[List[str]] = None, + sort_field: Optional[str] = None, + ascending: Optional[bool] = None, + num_chunks: Optional[int] = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: Optional[List[str]] = None, + ): + """ + Query dielectric docs using a variety of search criteria. + + Arguments: + max_bond_length (Tuple[float,float]): Minimum and maximum value for the maximum bond length + in the structure to consider. + min_bond_length (Tuple[float,float]): Minimum and maximum value for the minimum bond length + in the structure to consider. + mean_bond_length (Tuple[float,float]): Minimum and maximum value for the mean bond length + in the structure to consider. + coordination_envs (List[str]): List of coordination environments to consider (e.g. ['Mo-S(6)', 'S-Mo(3)']). + coordination_envs_anonymous (List[str]): List of anonymous coordination environments to consider + (e.g. ['A-B(6)', 'A-B(3)']). + sort_field (str): Field used to sort results. + ascending (bool): Whether sorting should be in ascending order. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in DielectricDoc to return data for. + Default is material_id and last_updated if all_fields is False. + + Returns: + ([BondingDoc]) List of bonding documents. + """ + + query_params = defaultdict(dict) # type: dict + + if max_bond_length: + query_params.update( + { + "max_bond_length_min": max_bond_length[0], + "max_bond_length_max": max_bond_length[1], + } + ) + + if min_bond_length: + query_params.update( + { + "min_bond_length_min": min_bond_length[0], + "min_bond_length_max": min_bond_length[1], + } + ) + + if mean_bond_length: + query_params.update( + { + "mean_bond_length_min": mean_bond_length[0], + "mean_bond_length_max": mean_bond_length[1], + } + ) + + if coordination_envs is not None: + query_params.update({"coordination_envs": ",".join(coordination_envs)}) + + if coordination_envs_anonymous is not None: + query_params.update( + {"coordination_envs_anonymous": ",".join(coordination_envs_anonymous)} + ) + + if sort_field: + query_params.update({"sort_field": sort_field}) + + if ascending is not None: + query_params.update({"ascending": ascending}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + + return super().search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params + ) diff --git a/src/mp_api/routes/bonds/query_operators.py b/src/mp_api/routes/bonds/query_operators.py new file mode 100644 index 00000000..a92d0ee9 --- /dev/null +++ b/src/mp_api/routes/bonds/query_operators.py @@ -0,0 +1,102 @@ +from typing import Optional +from fastapi import Query +from maggma.api.query_operator import QueryOperator +from maggma.api.utils import STORE_PARAMS + +from collections import defaultdict + + +class BondLengthQuery(QueryOperator): + """ + Method to generate a query on bond length data. + """ + + def query( + self, + max_bond_length_max: Optional[float] = Query( + None, + description="Maximum value for the maximum bond length in the structure.", + ), + max_bond_length_min: Optional[float] = Query( + None, + description="Minimum value for the maximum bond length in the structure.", + ), + min_bond_length_max: Optional[float] = Query( + None, + description="Maximum value for the minimum bond length in the structure.", + ), + min_bond_length_min: Optional[float] = Query( + None, + description="Minimum value for the minimum bond length in the structure.", + ), + mean_bond_length_max: Optional[float] = Query( + None, + description="Maximum value for the mean bond length in the structure.", + ), + mean_bond_length_min: Optional[float] = Query( + None, + description="Minimum value for the mean bond length in the structure.", + ), + ) -> STORE_PARAMS: + + crit = defaultdict(dict) # type: dict + + d = { + "bond_length_stats.max": [max_bond_length_min, max_bond_length_max], + "bond_length_stats.min": [min_bond_length_min, min_bond_length_max], + "bond_length_stats.mean": [mean_bond_length_min, mean_bond_length_max], + } + + for entry in d: + if d[entry][0] is not None: + crit[entry]["$gte"] = d[entry][0] + + if d[entry][1] is not None: + crit[entry]["$lte"] = d[entry][1] + + return {"criteria": crit} + + def ensure_indexes(self): # pragma: no cover + keys = [ + "bond_length_stats.max", + "bond_length_stats.min", + "bond_length_stats.mean", + ] + return [(key, False) for key in keys] + + +class CoordinationEnvsQuery(QueryOperator): + """ + Method to generate a query on coordination environment data. + """ + + def query( + self, + coordination_envs: Optional[str] = Query( + None, + description="Query by coordination environments in the material composition as a comma-separated list\ + (e.g. 'Mo-S(6),S-Mo(3)')", + ), + coordination_envs_anonymous: Optional[str] = Query( + None, + description="Query by anonymous coordination environments in the material composition as a comma-separated\ + list (e.g. 'A-B(6),A-B(3)')", + ), + ) -> STORE_PARAMS: + + crit = {} # type: dict + + if coordination_envs: + env_list = [env.strip() for env in coordination_envs.split(",")] + crit["coordination_envs"] = {"$all": [str(env) for env in env_list]} + + if coordination_envs_anonymous: + env_list = [env.strip() for env in coordination_envs_anonymous.split(",")] + crit["coordination_envs_anonymous"] = { + "$all": [str(env) for env in env_list] + } + + return {"criteria": crit} + + def ensure_indexes(self): # pragma: no cover + return [("coordination_envs", False), ("coordination_envs_anonymous", False)] diff --git a/src/mp_api/routes/bonds/resources.py b/src/mp_api/routes/bonds/resources.py new file mode 100644 index 00000000..f41d4472 --- /dev/null +++ b/src/mp_api/routes/bonds/resources.py @@ -0,0 +1,26 @@ +from maggma.api.resource import ReadOnlyResource +from emmet.core.bonds import BondingDoc + +from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery + +from mp_api.routes.bonds.query_operators import BondLengthQuery, CoordinationEnvsQuery + + +def bonds_resource(bonds_store): + resource = ReadOnlyResource( + bonds_store, + BondingDoc, + query_operators=[ + BondLengthQuery(), + CoordinationEnvsQuery(), + SortQuery(), + PaginationQuery(), + SparseFieldsQuery( + BondingDoc, default_fields=["material_id", "last_updated"], + ), + ], + tags=["Bonds"], + disable_validation=True, + ) + + return resource diff --git a/src/mp_api/routes/dielectric/client.py b/src/mp_api/routes/dielectric/client.py index f5654bd0..f4efbf33 100644 --- a/src/mp_api/routes/dielectric/client.py +++ b/src/mp_api/routes/dielectric/client.py @@ -2,14 +2,14 @@ from collections import defaultdict from mp_api.core.client import BaseRester -from mp_api.routes.dielectric.models import DielectricDoc +from emmet.core.polar import DielectricDoc class DielectricRester(BaseRester[DielectricDoc]): suffix = "dielectric" document_model = DielectricDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" def search_dielectric_docs( self, diff --git a/src/mp_api/routes/dielectric/models.py b/src/mp_api/routes/dielectric/models.py deleted file mode 100644 index a15c2930..00000000 --- a/src/mp_api/routes/dielectric/models.py +++ /dev/null @@ -1,31 +0,0 @@ -from datetime import datetime -from monty.json import MontyDecoder - -from emmet.core.polar import DielectricDoc as BaseDielectricDoc - -from pydantic import BaseModel, Field, validator - - -class DielectricDoc(BaseModel): - """ - Dielectric tensor and associated information. - """ - - dielectric: BaseDielectricDoc = Field( - None, description="Dielectric data", - ) - - task_id: str = Field( - None, - description="The Materials Project ID of the material. This comes in the form: mp-******", - ) - - last_updated: datetime = Field( - None, - description="Timestamp for the most recent calculation for this Material document", - ) - - # Make sure that the datetime field is properly formatted - @validator("last_updated", pre=True) - def last_updated_dict_ok(cls, v): - return MontyDecoder().process_decoded(v) diff --git a/src/mp_api/routes/dielectric/query_operators.py b/src/mp_api/routes/dielectric/query_operators.py index ffb484df..564f3ba9 100644 --- a/src/mp_api/routes/dielectric/query_operators.py +++ b/src/mp_api/routes/dielectric/query_operators.py @@ -42,10 +42,10 @@ def query( crit = defaultdict(dict) # type: dict d = { - "dielectric.e_total": [e_total_min, e_total_max], - "dielectric.e_ionic": [e_ionic_min, e_ionic_max], - "dielectric.e_electronic": [e_electronic_min, e_electronic_max], - "dielectric.n": [n_min, n_max], + "e_total": [e_total_min, e_total_max], + "e_ionic": [e_ionic_min, e_ionic_max], + "e_electronic": [e_electronic_min, e_electronic_max], + "n": [n_min, n_max], } for entry in d: @@ -59,9 +59,9 @@ def query( def ensure_indexes(self): # pragma: no cover keys = [ - "dielectric.e_total", - "dielectric.e_ionic", - "dielectric.e_electronic", - "dielectric.n", + "e_total", + "e_ionic", + "e_electronic", + "n", ] return [(key, False) for key in keys] diff --git a/src/mp_api/routes/dielectric/resources.py b/src/mp_api/routes/dielectric/resources.py index 976381e4..b735d268 100644 --- a/src/mp_api/routes/dielectric/resources.py +++ b/src/mp_api/routes/dielectric/resources.py @@ -1,5 +1,5 @@ from maggma.api.resource import ReadOnlyResource -from mp_api.routes.dielectric.models import DielectricDoc +from emmet.core.polar import DielectricDoc from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.dielectric.query_operators import DielectricQuery @@ -14,7 +14,7 @@ def dielectric_resource(dielectric_store): SortQuery(), PaginationQuery(), SparseFieldsQuery( - DielectricDoc, default_fields=["task_id", "last_updated"] + DielectricDoc, default_fields=["material_id", "last_updated"] ), ], tags=["Dielectric"], diff --git a/src/mp_api/routes/magnetism/client.py b/src/mp_api/routes/magnetism/client.py index a4103770..65793d9d 100644 --- a/src/mp_api/routes/magnetism/client.py +++ b/src/mp_api/routes/magnetism/client.py @@ -2,7 +2,7 @@ from collections import defaultdict from mp_api.core.client import BaseRester -from mp_api.routes.magnetism.models import MagnetismDoc +from emmet.core.magnetism import MagnetismDoc from pymatgen.analysis.magnetism import Ordering @@ -11,7 +11,7 @@ class MagnetismRester(BaseRester[MagnetismDoc]): suffix = "magnetism" document_model = MagnetismDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" def search_magnetism_docs( self, diff --git a/src/mp_api/routes/magnetism/models.py b/src/mp_api/routes/magnetism/models.py deleted file mode 100644 index 78774eb0..00000000 --- a/src/mp_api/routes/magnetism/models.py +++ /dev/null @@ -1,78 +0,0 @@ -from enum import Enum -from typing import List - -from pydantic import BaseModel, Field, validator -from monty.json import MontyDecoder -from datetime import datetime - - -class MagnetismData(BaseModel): - """ - Model for magnetic data within a magnetism doc - """ - - ordering: str = Field( - None, description="Magnetic ordering.", - ) - - is_magnetic: bool = Field( - None, description="Whether the material is magnetic.", - ) - - exchange_symmetry: int = Field( - None, description="Exchange symmetry.", - ) - - num_magnetic_sites: int = Field( - None, description="The number of magnetic sites.", - ) - - num_unique_magnetic_sites: int = Field( - None, description="The number of unique magnetic sites.", - ) - - types_of_magnetic_species: List[str] = Field( - None, description="Magnetic specie elements.", - ) - - magmoms: List[float] = Field( - None, description="Magnetic moments for each site.", - ) - - total_magnetization: float = Field( - None, description="Total magnetization in μB.", - ) - - total_magnetization_normalized_vol: float = Field( - None, description="Total magnetization normalized by volume in μB/ų.", - ) - - total_magnetization_normalized_formula_units: float = Field( - None, description="Total magnetization normalized by formula unit in μB/f.u. .", - ) - - -class MagnetismDoc(BaseModel): - """ - Magnetic ordering, total magnetizaiton, ... - """ - - task_id: str = Field( - None, - description="The ID of this material, used as a universal reference across property documents." - "This comes in the form: mp-******", - ) - - magnetism: MagnetismData = Field( - None, description="Magnetic data for the material", - ) - - last_updated: datetime = Field( - None, - description="Timestamp for the most recent calculation for this Material document", - ) - - # Make sure that the datetime field is properly formatted - @validator("last_updated", pre=True) - def last_updated_dict_ok(cls, v): - return MontyDecoder().process_decoded(v) diff --git a/src/mp_api/routes/magnetism/query_operators.py b/src/mp_api/routes/magnetism/query_operators.py index c5bcf2d0..e6193ead 100644 --- a/src/mp_api/routes/magnetism/query_operators.py +++ b/src/mp_api/routes/magnetism/query_operators.py @@ -59,23 +59,17 @@ def query( crit = defaultdict(dict) # type: dict d = { - "magnetism.total_magnetization": [ - total_magnetization_min, - total_magnetization_max, - ], - "magnetism.total_magnetization_normalized_vol": [ + "total_magnetization": [total_magnetization_min, total_magnetization_max], + "total_magnetization_normalized_vol": [ total_magnetization_normalized_vol_min, total_magnetization_normalized_vol_max, ], - "magnetism.total_magnetization_normalized_formula_units": [ + "total_magnetization_normalized_formula_units": [ total_magnetization_normalized_formula_units_min, total_magnetization_normalized_formula_units_max, ], - "magnetism.num_magnetic_sites": [ - num_magnetic_sites_min, - num_magnetic_sites_max, - ], - "magnetism.num_unique_magnetic_sites": [ + "num_magnetic_sites": [num_magnetic_sites_min, num_magnetic_sites_max], + "num_unique_magnetic_sites": [ num_unique_magnetic_sites_min, num_unique_magnetic_sites_max, ], @@ -89,7 +83,7 @@ def query( crit[entry]["$lte"] = d[entry][1] if ordering: - crit["magnetism.ordering"] = ordering.value + crit["ordering"] = ordering.value return {"criteria": crit} @@ -99,5 +93,5 @@ def ensure_indexes(self): # pragma: no cover for key in keys: if "_min" in key: key = key.replace("_min", "") - indexes.append(("magnetism." + key, False)) + indexes.append((key, False)) return indexes diff --git a/src/mp_api/routes/magnetism/resources.py b/src/mp_api/routes/magnetism/resources.py index 49205c24..d179823e 100644 --- a/src/mp_api/routes/magnetism/resources.py +++ b/src/mp_api/routes/magnetism/resources.py @@ -1,5 +1,5 @@ from maggma.api.resource import ReadOnlyResource -from mp_api.routes.magnetism.models import MagnetismDoc +from emmet.core.magnetism import MagnetismDoc from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.magnetism.query_operators import MagneticQuery @@ -13,7 +13,9 @@ def magnetism_resource(magnetism_store): MagneticQuery(), SortQuery(), PaginationQuery(), - SparseFieldsQuery(MagnetismDoc, default_fields=["task_id", "last_updated"]), + SparseFieldsQuery( + MagnetismDoc, default_fields=["material_id", "last_updated"] + ), ], tags=["Magnetism"], disable_validation=True, diff --git a/src/mp_api/routes/piezo/client.py b/src/mp_api/routes/piezo/client.py index 8fcbb89f..261439f3 100644 --- a/src/mp_api/routes/piezo/client.py +++ b/src/mp_api/routes/piezo/client.py @@ -2,16 +2,16 @@ from collections import defaultdict from mp_api.core.client import BaseRester -from mp_api.routes.piezo.models import PiezoDoc +from emmet.core.polar import PiezoelectricDoc import warnings -class PiezoRester(BaseRester[PiezoDoc]): +class PiezoRester(BaseRester[PiezoelectricDoc]): suffix = "piezoelectric" - document_model = PiezoDoc # type: ignore - primary_key = "task_id" + document_model = PiezoelectricDoc # type: ignore + primary_key = "material_id" def search_piezoelectric_docs( self, diff --git a/src/mp_api/routes/piezo/models.py b/src/mp_api/routes/piezo/models.py deleted file mode 100644 index 25e5fb70..00000000 --- a/src/mp_api/routes/piezo/models.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import List -from datetime import datetime -from monty.json import MontyDecoder - -from pydantic import BaseModel, Field, validator - - -class PiezoData(BaseModel): - """ - Piezoelectric tensor and associated information. - """ - - total: List[List[float]] = Field( - None, - description="Total piezoelectric tensor in C/m²", - ) - - ionic: List[List[float]] = Field( - None, - description="Ionic contribution to piezoelectric tensor in C/m²", - ) - - static: List[List[float]] = Field( - None, - description="Electronic contribution to piezoelectric tensor in C/m²", - ) - - e_ij_max: float = Field( - None, - description="Piezoelectric modulus", - ) - - max_direction: List[float] = Field( - None, - description="Crystallographic direction", - ) - - -class PiezoDoc(BaseModel): - """ - Model for a document containing piezoelectric data - """ - - piezo: PiezoData = Field( - None, - description="Piezoelectric data", - ) - - task_id: str = Field( - None, - description="The Materials Project ID of the material. This comes in the form: mp-******", - ) - - last_updated: datetime = Field( - None, - description="Timestamp for the most recent calculation for this Material document", - ) - - # Make sure that the datetime field is properly formatted - @validator("last_updated", pre=True) - def last_updated_dict_ok(cls, v): - return MontyDecoder().process_decoded(v) diff --git a/src/mp_api/routes/piezo/query_operators.py b/src/mp_api/routes/piezo/query_operators.py index 3ac51a6b..1a84d423 100644 --- a/src/mp_api/routes/piezo/query_operators.py +++ b/src/mp_api/routes/piezo/query_operators.py @@ -24,7 +24,7 @@ def query( crit = defaultdict(dict) # type: dict d = { - "piezo.e_ij_max": [piezo_modulus_min, piezo_modulus_max], + "e_ij_max": [piezo_modulus_min, piezo_modulus_max], } for entry in d: @@ -37,4 +37,4 @@ def query( return {"criteria": crit} def ensure_indexes(self): # pragma: no cover - return [("piezo.e_ij_max", False)] + return [("e_ij_max", False)] diff --git a/src/mp_api/routes/piezo/resources.py b/src/mp_api/routes/piezo/resources.py index 110603ea..9755cb3b 100644 --- a/src/mp_api/routes/piezo/resources.py +++ b/src/mp_api/routes/piezo/resources.py @@ -1,5 +1,5 @@ from maggma.api.resource import ReadOnlyResource -from mp_api.routes.piezo.models import PiezoDoc +from emmet.core.polar import PiezoelectricDoc from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.piezo.query_operators import PiezoelectricQuery @@ -8,12 +8,14 @@ def piezo_resource(piezo_store): resource = ReadOnlyResource( piezo_store, - PiezoDoc, + PiezoelectricDoc, query_operators=[ PiezoelectricQuery(), SortQuery(), PaginationQuery(), - SparseFieldsQuery(PiezoDoc, default_fields=["task_id", "last_updated"]), + SparseFieldsQuery( + PiezoelectricDoc, default_fields=["material_id", "last_updated"] + ), ], tags=["Piezoelectric"], disable_validation=True, diff --git a/src/mp_api/routes/thermo/client.py b/src/mp_api/routes/thermo/client.py index d4bb1b9a..fb936d40 100644 --- a/src/mp_api/routes/thermo/client.py +++ b/src/mp_api/routes/thermo/client.py @@ -1,8 +1,8 @@ from collections import defaultdict from typing import Optional, List, Tuple -from mp_api.core.client import BaseRester +from mp_api.core.client import BaseRester, MPRestError from emmet.core.thermo import ThermoDoc -from pymatgen.core.periodic_table import Element +from pymatgen.analysis.phase_diagram import PhaseDiagram class ThermoRester(BaseRester[ThermoDoc]): @@ -111,3 +111,21 @@ def search_thermo_docs( fields=fields, **query_params, ) + + def get_phase_diagram_from_chemsys(self, chemsys: str) -> PhaseDiagram: + """ + Get a pre-computed phase diagram for a given chemsys. + + Arguments: + material_id (str): Materials project ID + Returns: + phase_diagram (PhaseDiagram): Pymatgen phase diagram object. + """ + + response = self._query_resource( + fields=["phase_diagram"], + suburl=f"phase_diagram/{chemsys}", + use_document_model=False, + ).get("data") + + return response[0]["phase_diagram"] # type: ignore diff --git a/src/mp_api/routes/thermo/resources.py b/src/mp_api/routes/thermo/resources.py index bce7bb70..0062ad3f 100644 --- a/src/mp_api/routes/thermo/resources.py +++ b/src/mp_api/routes/thermo/resources.py @@ -1,6 +1,7 @@ from maggma.api.query_operator.dynamic import NumericQuery from maggma.api.resource import ReadOnlyResource from emmet.core.thermo import ThermoDoc +from emmet.core.thermo import PhaseDiagramDoc from maggma.api.query_operator import ( PaginationQuery, @@ -8,10 +9,22 @@ SparseFieldsQuery, ) from mp_api.routes.thermo.query_operators import IsStableQuery - from mp_api.routes.materials.query_operators import MultiMaterialIDQuery, FormulaQuery +def phase_diagram_resource(phase_diagram_store): + resource = ReadOnlyResource( + phase_diagram_store, + PhaseDiagramDoc, + tags=["Thermo"], + sub_path="/phase_diagram/", + disable_validation=True, + enable_default_search=False, + ) + + return resource + + def thermo_resource(thermo_store): resource = ReadOnlyResource( thermo_store, diff --git a/tests/bonds/__init__.py b/tests/bonds/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bonds/test_client.py b/tests/bonds/test_client.py new file mode 100644 index 00000000..c1b4fe81 --- /dev/null +++ b/tests/bonds/test_client.py @@ -0,0 +1,85 @@ +import os +import pytest +from mp_api.routes.bonds.client import BondsRester + +import inspect +import typing + +resters = [BondsRester()] + +excluded_params = [ + "sort_field", + "ascending", + "chunk_size", + "num_chunks", + "all_fields", + "fields", +] + +sub_doc_fields = [] # type: list + +alt_name_dict = {} # type: dict + +custom_field_tests = {} # type: dict + + +@pytest.mark.xfail(reason="Needs deployment") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) +@pytest.mark.parametrize("rester", resters) +def test_client(rester): + # Get specific search method + search_method = None + for entry in inspect.getmembers(rester, predicate=inspect.ismethod): + if "search" in entry[0] and entry[0] != "search": + search_method = entry[1] + + if search_method is not None: + # Get list of parameters + param_tuples = list(typing.get_type_hints(search_method).items()) + + # Query API for each numeric and bollean parameter and check if returned + for entry in param_tuples: + param = entry[0] + if param not in excluded_params: + param_type = entry[1].__args__[0] + q = None + if param_type is typing.Tuple[int, int]: + project_field = alt_name_dict.get(param, None) + q = { + param: (-100, 100), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type is typing.Tuple[float, float]: + project_field = alt_name_dict.get(param, None) + q = { + param: (0, 100.12), + "chunk_size": 1, + "num_chunks": 1, + } + elif param_type is bool: + project_field = alt_name_dict.get(param, None) + q = { + param: False, + "chunk_size": 1, + "num_chunks": 1, + } + elif param in custom_field_tests: + project_field = alt_name_dict.get(param, None) + q = { + param: custom_field_tests[param], + "chunk_size": 1, + "num_chunks": 1, + } + + doc = search_method(**q)[0].dict() + for sub_field in sub_doc_fields: + if sub_field in doc: + doc = doc[sub_field] + + assert ( + doc[project_field if project_field is not None else param] + is not None + ) diff --git a/tests/bonds/test_query_operators.py b/tests/bonds/test_query_operators.py new file mode 100644 index 00000000..bc9fdf3d --- /dev/null +++ b/tests/bonds/test_query_operators.py @@ -0,0 +1,68 @@ +from mp_api.routes.bonds.query_operators import BondLengthQuery, CoordinationEnvsQuery + +from monty.tempfile import ScratchDir +from monty.serialization import loadfn, dumpfn + + +def test_bond_length_query_operator(): + op = BondLengthQuery() + + q = op.query( + max_bond_length_min=0, + max_bond_length_max=5, + min_bond_length_min=0, + min_bond_length_max=5, + mean_bond_length_min=0, + mean_bond_length_max=5, + ) + + fields = [ + "bond_length_stats.min", + "bond_length_stats.max", + "bond_length_stats.mean", + ] + + assert q == {"criteria": {field: {"$gte": 0, "$lte": 5} for field in fields}} + + with ScratchDir("."): + dumpfn(op, "temp.json") + new_op = loadfn("temp.json") + q = new_op.query( + max_bond_length_min=0, + max_bond_length_max=5, + min_bond_length_min=0, + min_bond_length_max=5, + mean_bond_length_min=0, + mean_bond_length_max=5, + ) + assert dict(q) == { + "criteria": {field: {"$gte": 0, "$lte": 5} for field in fields} + } + + +def test_coordination_envs_query(): + op = CoordinationEnvsQuery() + + assert op.query( + coordination_envs="Mo-S(6),S-Mo(3)", + coordination_envs_anonymous="A-B(6),A-B(3)", + ) == { + "criteria": { + "coordination_envs": {"$all": ["Mo-S(6)", "S-Mo(3)"]}, + "coordination_envs_anonymous": {"$all": ["A-B(6)", "A-B(3)"]}, + } + } + + with ScratchDir("."): + dumpfn(op, "temp.json") + new_op = loadfn("temp.json") + assert new_op.query( + coordination_envs="Mo-S(6),S-Mo(3)", + coordination_envs_anonymous="A-B(6),A-B(3)", + ) == { + "criteria": { + "coordination_envs": {"$all": ["Mo-S(6)", "S-Mo(3)"]}, + "coordination_envs_anonymous": {"$all": ["A-B(6)", "A-B(3)"]}, + } + } + diff --git a/tests/dielectric/test_client.py b/tests/dielectric/test_client.py index 09bad450..21fe100f 100644 --- a/tests/dielectric/test_client.py +++ b/tests/dielectric/test_client.py @@ -16,13 +16,14 @@ "fields", ] -sub_doc_fields = ["dielectric"] # type: list +sub_doc_fields = [] # type: list alt_name_dict = {"e_static": "e_ionic"} # type: dict custom_field_tests = {} # type: dict +@pytest.mark.xfail(reason="Needs deployment") @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) diff --git a/tests/dielectric/test_query_operators.py b/tests/dielectric/test_query_operators.py index ef38dfc4..9e64afef 100644 --- a/tests/dielectric/test_query_operators.py +++ b/tests/dielectric/test_query_operators.py @@ -20,10 +20,10 @@ def test_dielectric_query_operator(): ) fields = [ - "dielectric.e_total", - "dielectric.e_ionic", - "dielectric.e_electronic", - "dielectric.n", + "e_total", + "e_ionic", + "e_electronic", + "n", ] assert q == {"criteria": {field: {"$gte": 0, "$lte": 5} for field in fields}} diff --git a/tests/electrodes/test_client.py b/tests/electrodes/test_client.py index fb63b53a..372baddb 100644 --- a/tests/electrodes/test_client.py +++ b/tests/electrodes/test_client.py @@ -24,7 +24,6 @@ custom_field_tests = {"working_ion": Element("Li")} # type: dict -@pytest.mark.xfail # temp until rebuild @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) diff --git a/tests/magnetism/test_client.py b/tests/magnetism/test_client.py index 12082464..d9bdfee1 100644 --- a/tests/magnetism/test_client.py +++ b/tests/magnetism/test_client.py @@ -17,13 +17,14 @@ "fields", ] -sub_doc_fields = ["magnetism"] # type: list +sub_doc_fields = [] # type: list alt_name_dict = {} # type: dict custom_field_tests = {"ordering": Ordering.FM} # type: dict +@pytest.mark.xfail(reason="Needs deployment") @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) diff --git a/tests/magnetism/test_query_operators.py b/tests/magnetism/test_query_operators.py index a18c036f..4da73ffa 100644 --- a/tests/magnetism/test_query_operators.py +++ b/tests/magnetism/test_query_operators.py @@ -24,16 +24,16 @@ def test_magnetic_query(): ) fields = [ - "magnetism.total_magnetization", - "magnetism.total_magnetization_normalized_vol", - "magnetism.total_magnetization_normalized_formula_units", - "magnetism.num_magnetic_sites", - "magnetism.num_unique_magnetic_sites", + "total_magnetization", + "total_magnetization_normalized_vol", + "total_magnetization_normalized_formula_units", + "num_magnetic_sites", + "num_unique_magnetic_sites", ] c = {field: {"$gte": 0, "$lte": 5} for field in fields} - assert q == {"criteria": {"magnetism.ordering": "FM", **c}} + assert q == {"criteria": {"ordering": "FM", **c}} with ScratchDir("."): dumpfn(op, "temp.json") @@ -53,4 +53,4 @@ def test_magnetic_query(): ) c = {field: {"$gte": 0, "$lte": 5} for field in fields} - assert q == {"criteria": {"magnetism.ordering": "FM", **c}} + assert q == {"criteria": {"ordering": "FM", **c}} diff --git a/tests/piezo/test_client.py b/tests/piezo/test_client.py index dc326752..1198068f 100644 --- a/tests/piezo/test_client.py +++ b/tests/piezo/test_client.py @@ -16,13 +16,14 @@ "fields", ] -sub_doc_fields = ["piezo"] # type: list +sub_doc_fields = [] # type: list alt_name_dict = {"piezoelectric_modulus": "e_ij_max"} # type: dict custom_field_tests = {} # type: dict +@pytest.mark.xfail(reason="Needs deployment") @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) diff --git a/tests/piezo/test_query_operators.py b/tests/piezo/test_query_operators.py index e062d102..cf15536e 100644 --- a/tests/piezo/test_query_operators.py +++ b/tests/piezo/test_query_operators.py @@ -8,12 +8,12 @@ def test_piezo_query(): op = PiezoelectricQuery() assert op.query(piezo_modulus_min=0, piezo_modulus_max=5) == { - "criteria": {"piezo.e_ij_max": {"$gte": 0, "$lte": 5}} + "criteria": {"e_ij_max": {"$gte": 0, "$lte": 5}} } with ScratchDir("."): dumpfn(op, "temp.json") new_op = loadfn("temp.json") assert new_op.query(piezo_modulus_min=0, piezo_modulus_max=5) == { - "criteria": {"piezo.e_ij_max": {"$gte": 0, "$lte": 5}} + "criteria": {"e_ij_max": {"$gte": 0, "$lte": 5}} } diff --git a/tests/robocrys/test_client.py b/tests/robocrys/test_client.py index dbfec9b2..bae8004b 100644 --- a/tests/robocrys/test_client.py +++ b/tests/robocrys/test_client.py @@ -14,7 +14,6 @@ def rester(): rester.session.close() -@pytest.mark.xfail # temp until data fix @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) @@ -27,10 +26,10 @@ def test_client(rester): if search_method is not None: - q = {"keywords": ["silicon", "process"]} + q = {"keywords": ["silicon"]} doc = search_method(**q)[0] + print(doc) assert doc.description is not None assert doc.condensed_structure is not None - assert doc.task_id is not None diff --git a/tests/test_client.py b/tests/test_client.py index 21c10507..4e0694b3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,7 +26,13 @@ "charge_density", ] -ignore_generic = ["_user_settings", "_general_store"] # temp +ignore_generic = [ + "_user_settings", + "_general_store", + "dielectric", + "piezoelectric", + "magnetism", +] # temp mpr = MPRester() diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 3c66b137..239045a7 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -80,6 +80,7 @@ def test_get_structures(self, mpr): structs = mpr.get_structures("Mn3O4", final=False) assert len(structs) > 0 + @pytest.mark.xfail(reason="Until deployment") def test_find_structure(self, mpr): path = os.path.join(MAPISettings().TEST_FILES, "Si_mp_149.cif") with open(path) as file: @@ -245,7 +246,7 @@ def test_query(self, mpr): "total_energy": "energy_per_atom", "formation_energy": "formation_energy_per_atom", "uncorrected_energy": "uncorrected_energy_per_atom", - # "equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom", + "equilibrium_reaction_energy": "equilibrium_reaction_energy_per_atom", "magnetic_ordering": "ordering", "elastic_anisotropy": "universal_anisotropy", "poisson_ratio": "homogeneous_poisson", @@ -261,10 +262,10 @@ def test_query(self, mpr): "crystal_system": CrystalSystem.cubic, "spacegroup_number": 38, "spacegroup_symbol": "Amm2", - "magnetic_ordering": Ordering.FM, "has_props": ["dielectric"], "theoretical": True, "has_reconstructed": False, + "magnetic_ordering": Ordering.FM, } # type: dict search_method = mpr.query @@ -308,7 +309,12 @@ def test_query(self, mpr): "num_chunks": 1, } - doc = search_method(**q)[0].dict() + docs = search_method(**q) + + if len(docs) > 0: + doc = docs[0].dict() + else: + raise ValueError("No documents returned") assert ( doc[project_field if project_field is not None else param] diff --git a/tests/thermo/test_client.py b/tests/thermo/test_client.py index 246bd917..c7c9b98c 100644 --- a/tests/thermo/test_client.py +++ b/tests/thermo/test_client.py @@ -1,4 +1,5 @@ import os +from pymatgen.analysis.phase_diagram import PhaseDiagram import pytest from mp_api.routes.thermo.client import ThermoRester @@ -106,3 +107,11 @@ def test_client(rester): doc[project_field if project_field is not None else param] is not None ) + + +@pytest.mark.xfail(reason="Temporary until deployment") +def test_get_phase_diagram_from_chemsys(): + # Test that a phase diagram is returned + assert isinstance( + ThermoRester().get_phase_diagram_from_chemsys("Fe-Mn-Pt"), PhaseDiagram + )