diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index c0494272b34..a6cbfb41d2b 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import itertools import logging from datetime import datetime @@ -297,7 +298,6 @@ async def online_read_async( batch_size = online_config.batch_size entity_ids = self._to_entity_ids(config, entity_keys) entity_ids_iter = iter(entity_ids) - result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] table_name = _get_table_name(online_config, config, table) deserialize = TypeDeserializer().deserialize @@ -309,24 +309,40 @@ def to_tbl_resp(raw_client_response): "values": deserialize(raw_client_response["values"]), } + batches = [] + entity_id_batches = [] + while True: + batch = list(itertools.islice(entity_ids_iter, batch_size)) + if not batch: + break + entity_id_batch = self._to_client_batch_get_payload( + online_config, table_name, batch + ) + batches.append(batch) + entity_id_batches.append(entity_id_batch) + async with self._get_aiodynamodb_client(online_config.region) as client: - while True: - batch = list(itertools.islice(entity_ids_iter, batch_size)) - - # No more items to insert - if len(batch) == 0: - break - batch_entity_ids = self._to_client_batch_get_payload( - online_config, table_name, batch - ) - response = await client.batch_get_item( - RequestItems=batch_entity_ids, - ) - batch_result = self._process_batch_get_response( - table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp - ) - result.extend(batch_result) - return result + response_batches = await asyncio.gather( + *[ + client.batch_get_item( + RequestItems=entity_id_batch, + ) + for entity_id_batch in entity_id_batches + ] + ) + + result_batches = [] + for batch, response in zip(batches, response_batches): + result_batch = self._process_batch_get_response( + table_name, + response, + entity_ids, + batch, + to_tbl_response=to_tbl_resp, + ) + result_batches.append(result_batch) + + return list(itertools.chain(*result_batches)) def _get_aioboto_session(self): if self._aioboto_session is None: diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index ea86bd9175a..cf2d68eb746 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import asyncio from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union @@ -240,7 +240,7 @@ async def get_online_features_async( native_entity_values=True, ) - for table, requested_features in grouped_refs: + async def query_table(table, requested_features): # Get the correct set of entity values with the correct join keys. table_entity_values, idxs = utils._get_unique_entities( table, @@ -258,6 +258,18 @@ async def get_online_features_async( requested_features=requested_features, ) + return idxs, read_rows + + all_responses = await asyncio.gather( + *[ + query_table(table, requested_features) + for table, requested_features in grouped_refs + ] + ) + + for (idxs, read_rows), (table, requested_features) in zip( + all_responses, grouped_refs + ): feature_data = utils._convert_rows_to_protobuf( requested_features, read_rows )