Skip to content

Commit

Permalink
Molecules client update
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed May 11, 2021
1 parent b26f32f commit c58eec6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 41 deletions.
55 changes: 14 additions & 41 deletions src/mp_api/routes/molecules/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymatgen.core.periodic_table import Element

from mp_api.core.client import BaseRester, MPRestError
from mp_api.core.client import BaseRester
from mp_api.routes.molecules.models import MoleculesDoc


Expand All @@ -13,24 +13,6 @@ class MoleculesRester(BaseRester):
suffix = "molecules"
document_model = MoleculesDoc # type: ignore

def get_molecule_from_molecule_id(self, molecule_id: str):
"""
Get molecule data for a given Materials Project molecule ID.
Arguments:
molecule_id (str): Materials project molecule ID
Returns:
results (Dict): Dictionary containing molecule data.
"""

result = self._make_request("{}/?all_fields=true".format(molecule_id))

if len(result.get("data", [])) > 0:
return result
else:
raise MPRestError("No document found")

def search_molecules_docs(
self,
elements: Optional[List[Element]] = None,
Expand All @@ -41,7 +23,8 @@ def search_molecules_docs(
pointgroup: Optional[str] = None,
smiles: Optional[str] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 100,
chunk_size: int = 1000,
all_fields: bool = True,
fields: Optional[List[str]] = None,
):
"""
Expand All @@ -57,20 +40,16 @@ def search_molecules_docs(
smiles (str): The simplified molecular input line-entry system (SMILES) representation of the molecule.
num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible.
chunk_size (int): Number of data entries per chunk.
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in SubstratesDoc to return data for.
Default is the film_id and substrate_id only.
Yields:
([dict]) List of dictionaries containing data for entries defined in 'fields'.
Defaults to Materials Project molecule IDs.
Returns:
([MoleculesDoc]) List of molecule documents
"""

query_params = defaultdict(dict) # type: dict

if chunk_size <= 0 or chunk_size > 100:
warnings.warn("Improper chunk size given. Setting value to 100.")
chunk_size = 100

if elements:
query_params.update({"elements": ",".join([str(ele) for ele in elements])})

Expand All @@ -94,23 +73,17 @@ def search_molecules_docs(
if charge:
query_params.update({"charge_min": charge[0], "charge_max": charge[1]})

if fields:
query_params.update({"fields": ",".join(fields)})

query_params = {
entry: query_params[entry]
for entry in query_params
if query_params[entry] is not None
}

query_params.update({"limit": chunk_size, "skip": 0})
count = 0
while True:
query_params["skip"] = count * chunk_size
results = self.query(query_params).get("data", [])

if not any(results) or (num_chunks is not None and count == num_chunks):
break

count += 1
yield results
return super().search(
version=self.version,
num_chunks=num_chunks,
chunk_size=chunk_size,
all_fields=all_fields,
fields=fields,
**query_params
)
13 changes: 13 additions & 0 deletions src/mp_api/routes/molecules/client.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import List, Optional
from mp_api.routes.molecules.models import MoleculesDoc


class MoleculeRester:

def get_document_by_id(
self,
document_id: str,
fields: Optional[List[str]] = None,
monty_decode: bool = True,
version: Optional[str] = None,
) -> MoleculesDoc: ...

0 comments on commit c58eec6

Please sign in to comment.