Enhance parallel requests and bug fixes (#496)
* Add parallel requests to initial page

* Factor out multithread handling

* Fix MPRester import

* Fix pagination with doc num tracking

* Bug fixes for pagination and enhanced use of threading

* Fix no meta in initial page result

* Ensure initial lone criteria has limit

* Fix synthesis client search

* Ensure limit in robocrys query

* Pytest ignore find structure
munrojm authored Jan 31, 2022
1 parent 667848e commit 560706e
Showing 5 changed files with 142 additions and 67 deletions.
1 change: 1 addition & 0 deletions src/mp_api/__init__.py
@@ -2,6 +2,7 @@
""" Primary MAPI module """
import os
from pkg_resources import get_distribution, DistributionNotFound
from mp_api.client import MPRester

try: # pragma: no cover
from setuptools_scm import get_version
202 changes: 137 additions & 65 deletions src/mp_api/core/client.py
@@ -16,7 +16,9 @@
from os import environ
from typing import Dict, Generic, List, Optional, TypeVar, Union, Tuple
from urllib.parse import urljoin
import operator
from copy import copy
from math import ceil
from matplotlib import use

import requests
from emmet.core.utils import jsanitize
@@ -353,12 +355,12 @@ def _submit_requests(

# Get new limit values that sum to chunk_size
num_new_params = len(new_param_values)
q = int(chunk_size / num_new_params)
r = chunk_size % num_new_params
q = int(chunk_size / num_new_params) # quotient
r = chunk_size % num_new_params # remainder
new_limits = []

for _ in range(num_new_params):
val = q + 1 if r > 0 else q
val = q + 1 if r > 0 else q if q > 0 else 1
new_limits.append(val)
r -= 1

@@ -391,30 +393,34 @@ def _submit_requests(
# new limit value we assigned.
subtotals = []
remaining_docs_avail = {}
for crit_ind, crit in enumerate(new_criteria):

# Check how much pagination is needed
response = self.session.get(url, verify=True, params=crit)
initial_params_list = [
{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria
]

data, subtotal = self._handle_response(response, use_document_model)
subtotals.append(subtotal)
initial_data_tuples = self._multi_thread(
use_document_model, initial_params_list
)

sub_diff = subtotal - new_limits[crit_ind]
for data, subtotal, crit_ind in initial_data_tuples:

subtotals.append(subtotal)
sub_diff = subtotal - new_limits[crit_ind]
remaining_docs_avail[crit_ind] = sub_diff

total_data["data"].extend(data["data"])

# Rebalance if some parallel queries produced too few results
if len(remaining_docs_avail) > 1:
last_data_entry = initial_data_tuples[-1][0]

# Rebalance if some parallel queries produced too few results
if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size:
remaining_docs_avail = dict(
sorted(remaining_docs_avail.items(), key=lambda item: item[1])
)

# Redistribute missing docs from initial chunk among queries
# which have head room with respect to remaining document number.
fill_docs = 0
rebalance_params = []
for crit_ind, amount_avail in remaining_docs_avail.items():
if amount_avail <= 0:
fill_docs += abs(amount_avail)
@@ -423,6 +429,9 @@ def _submit_requests(
crit = new_criteria[crit_ind]
crit["skip"] = crit["limit"]

if fill_docs == 0:
continue

if fill_docs >= amount_avail:
crit["limit"] = amount_avail
new_limits[crit_ind] += amount_avail
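
For illustration only, a simplified, self-contained sketch of the rebalancing idea; this is not the code in the diff, which also mutates new_criteria and the skip values in place. Criteria that returned fewer documents than their assigned limit create a shortfall, and that shortfall is re-requested from criteria that still have headroom. The helper name plan_rebalance and the numbers in the example are invented.

    def plan_rebalance(new_limits: list, subtotals: list) -> list:
        # Headroom (positive) or shortfall (zero/negative) for each criteria index
        remaining = {i: sub - lim for i, (lim, sub) in enumerate(zip(new_limits, subtotals))}
        fill_docs = sum(-avail for avail in remaining.values() if avail <= 0)

        extra_requests = []
        # Visit criteria in order of increasing headroom, as in the sorted dict above
        for i, avail in sorted(remaining.items(), key=lambda item: item[1]):
            if avail <= 0 or fill_docs == 0:
                continue
            take = min(avail, fill_docs)
            extra_requests.append({"crit_ind": i, "skip": new_limits[i], "limit": take})
            fill_docs -= take
        return extra_requests

    # e.g. limits [3, 3, 3] with subtotals [1, 3, 20] leave a shortfall of 2 docs, so
    # plan_rebalance([3, 3, 3], [1, 3, 20]) -> [{"crit_ind": 2, "skip": 3, "limit": 2}]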
@@ -433,43 +442,60 @@ def _submit_requests(
new_limits[crit_ind] += fill_docs
fill_docs = 0

response = self.session.get(url, verify=True, params=crit)
data, _ = self._handle_response(response, use_document_model)
total_data["data"].extend(data["data"])
rebalance_params.append(
{"url": url, "verify": True, "params": copy(crit)}
)

new_criteria[crit_ind]["skip"] += crit["limit"]
new_criteria[crit_ind]["limit"] = chunk_size

if fill_docs == 0:
break
# Obtain missing initial data after rebalancing
if len(rebalance_params) > 0:

total_num_docs = sum(subtotals)
rebalance_data_tuples = self._multi_thread(
use_document_model, rebalance_params
)

if "meta" in data:
data["meta"]["total_doc"] = total_num_docs
total_data["meta"] = data["meta"]
for data, _, _ in rebalance_data_tuples:
total_data["data"].extend(data["data"])

# If we have all the results in a single page, return directly
if len(total_data["data"]) == total_num_docs or num_chunks == 1:
return total_data
last_data_entry = rebalance_data_tuples[-1][0]

if chunk_size is None:
raise ValueError("A chunk size must be provided to enable pagination")
total_num_docs = sum(subtotals)

if "meta" in last_data_entry:
last_data_entry["meta"]["total_doc"] = total_num_docs
total_data["meta"] = last_data_entry["meta"]

# otherwise prepare to paginate in parallel
# Get max number of response pages
max_pages = (
num_chunks
if num_chunks is not None
else (int(total_num_docs / chunk_size) + 1)
num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
)

if num_chunks is not None:
total_num_docs = min(len(total_data["data"]) * num_chunks, total_num_docs)
# Get total number of docs needed
num_docs_needed = min((max_pages * chunk_size), total_num_docs)

# Setup progress bar
t = tqdm(
pbar = tqdm(
desc=f"Retrieving {self.document_model.__name__} documents", # type: ignore
total=total_num_docs,
total=num_docs_needed,
)
t.update(len(total_data["data"]))

initial_data_length = len(total_data["data"])

# If we have all the results in a single page, return directly
if initial_data_length >= num_docs_needed or num_chunks == 1:
new_total_data = copy(total_data)
new_total_data["data"] = total_data["data"][:num_docs_needed]
pbar.update(num_docs_needed)
pbar.close()
return new_total_data

# otherwise, prepare to paginate in parallel
if chunk_size is None:
raise ValueError("A chunk size must be provided to enable pagination")

pbar.update(initial_data_length)

# Warning to select specific fields only for many results
if criteria.get("all_fields", False) and (total_num_docs / chunk_size > 10):
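
To make the page arithmetic above concrete, a hedged toy calculation (the numbers are invented, not from the source):

    from math import ceil

    total_num_docs, chunk_size, num_chunks = 2537, 1000, None

    # ceil() replaces the old int(total / chunk) + 1, which over-counted on exact multiples
    max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)  # 3
    num_docs_needed = min(max_pages * chunk_size, total_num_docs)                            # 2537

    # With num_chunks=2 the cap applies instead: max_pages = 2, num_docs_needed = 2000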
@@ -481,64 +507,110 @@ def _submit_requests(

# Get all pagination input params for parallel requests
params_list = []
exit = False

for page_num in range(0, max_pages - 1):
for crit_num, crit in enumerate(new_criteria):
doc_counter = 0

if new_limits[crit_num] == 0:
continue
for crit_num, crit in enumerate(new_criteria):
remaining = remaining_docs_avail[crit_num]
if "skip" not in crit:
crit["skip"] = chunk_size if "limit" not in crit else crit["limit"]

if (
num_chunks is not None
and (((page_num + 1) * (crit_num + 1))) == num_chunks
):
exit = True
while remaining > 0:
if doc_counter == (num_docs_needed - initial_data_length):
break

skip = new_limits[crit_num] + int(page_num * chunk_size)
if remaining < chunk_size:
crit["limit"] = remaining
doc_counter += remaining
else:
n = chunk_size - (doc_counter % chunk_size)
crit["limit"] = n
doc_counter += n

params_list.append(
{"url": url, "verify": True, "params": {**crit, "skip": skip}}
{
"url": url,
"verify": True,
"params": {**crit, "skip": crit["skip"]},
}
)

if exit:
break
crit["skip"] += crit["limit"]
remaining -= crit["limit"]

# Submit requests and process data
data_tuples = self._multi_thread(use_document_model, params_list, pbar)

for data, _, _ in data_tuples:
total_data["data"].extend(data["data"])

if "meta" in data:
total_data["meta"]["time_stamp"] = data["meta"]["time_stamp"]

pbar.close()

return total_data

def _multi_thread(
self,
use_document_model: bool,
params_list: List[dict],
progress_bar: tqdm = None,
):
"""
Handles setting up a threadpool and sending parallel requests
Arguments:
use_document_model (bool): if None, will defer to the self.use_document_model attribute
params_list (list): list of dictionaries containing url and params for each request
progress_bar (tqdm): progress bar to update with progress
Returns:
Tuples with data, the total number of docs matching the query in the database,
and the index of the criteria dictionary in the provided parameter list
"""

return_data = []

params_gen = iter(
params_list
) # Iter necessary for islice to keep track of what has been accessed

params_ind = 0

with ThreadPoolExecutor(
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS
) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = {
executor.submit(self.session.get, **params)
for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
)
}
futures = set({})
for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
):

future = executor.submit(self.session.get, **params)
setattr(future, "crit_ind", params_ind)
futures.add(future)
params_ind += 1

while futures:
# Wait for at least one future to complete and process finished
finished, futures = wait(futures, return_when=FIRST_COMPLETED)

for future in finished:
response = future.result()
data, _ = self._handle_response(response, use_document_model)
t.update(len(data["data"]))
total_data["data"].extend(data["data"])
data, subtotal = self._handle_response(response, use_document_model)
if progress_bar is not None:
progress_bar.update(len(data["data"]))
return_data.append((data, subtotal, future.crit_ind)) # type: ignore

# Populate more futures to replace finished
for params in itertools.islice(params_gen, len(finished)):
futures.add(executor.submit(self.session.get, **params))
new_future = executor.submit(self.session.get, **params)
setattr(new_future, "crit_ind", params_ind)
futures.add(new_future)
params_ind += 1

if "meta" in data:
total_data["meta"]["time_stamp"] = data["meta"]["time_stamp"]

return total_data
return return_data

def _handle_response(
self, response: requests.Response, use_document_model: bool
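For readers unfamiliar with the pattern _multi_thread uses, here is a minimal standalone sketch of the same rolling-window approach: submit an initial batch of futures, then top the pool back up as requests finish. The fetch callable, the max_parallel default, and the fetch_all name are illustrative and not part of the client.

    import itertools
    from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED

    def fetch_all(fetch, params_list, max_parallel=8):
        """Call fetch(**params) for every entry, keeping at most max_parallel requests in flight."""
        results = []
        params_gen = iter(params_list)  # islice consumes this iterator as futures complete

        with ThreadPoolExecutor(max_workers=max_parallel) as executor:
            futures = {
                executor.submit(fetch, **params)
                for params in itertools.islice(params_gen, max_parallel)
            }
            while futures:
                # Block until at least one request finishes, then collect its results
                finished, futures = wait(futures, return_when=FIRST_COMPLETED)
                results.extend(f.result() for f in finished)

                # Refill the in-flight set with the next pending requests, if any
                futures |= {
                    executor.submit(fetch, **params)
                    for params in itertools.islice(params_gen, len(finished))
                }
        return results
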
2 changes: 1 addition & 1 deletion src/mp_api/routes/robocrys.py
@@ -31,7 +31,7 @@ def search_robocrys_text(
keyword_string = ",".join(keywords)

robocrys_docs = self._query_resource(
criteria={"keywords": keyword_string},
criteria={"keywords": keyword_string, "limit": chunk_size},
suburl="text_search",
use_document_model=True,
chunk_size=100,
3 changes: 2 additions & 1 deletion src/mp_api/routes/synthesis.py
@@ -27,7 +27,7 @@ def search_synthesis_text(
condition_mixing_device: Optional[List[str]] = None,
condition_mixing_media: Optional[List[str]] = None,
num_chunks: Optional[int] = None,
chunk_size: Optional[int] = 100,
chunk_size: Optional[int] = 10,
):
"""
Search synthesis recipe text.
@@ -72,6 +72,7 @@
"condition_heating_atmosphere": condition_heating_atmosphere,
"condition_mixing_device": condition_mixing_device,
"condition_mixing_media": condition_mixing_media,
"limit": chunk_size,
},
chunk_size=chunk_size,
num_chunks=num_chunks,
1 change: 1 addition & 0 deletions tests/test_mprester.py
@@ -86,6 +86,7 @@ def test_get_structures(self, mpr):
structs = mpr.get_structures("Mn-O", final=False)
assert len(structs) > 0

@pytest.mark.skip(reason="endpoint issues")
def test_find_structure(self, mpr):
path = os.path.join(MAPIClientSettings().TEST_FILES, "Si_mp_149.cif")
with open(path) as file:
