Skip to content

Commit

Permalink
feat: Add online_read_async for dynamodb (#4244)
Browse files Browse the repository at this point in the history
  • Loading branch information
robhowley committed Jun 5, 2024
1 parent 15524ce commit b5ef384
Show file tree
Hide file tree
Showing 10 changed files with 790 additions and 156 deletions.
191 changes: 146 additions & 45 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

try:
import boto3
from aiobotocore import session
from boto3.dynamodb.types import TypeDeserializer
from botocore.config import Config
from botocore.exceptions import ClientError
except ImportError as e:
Expand Down Expand Up @@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):

_dynamodb_client = None
_dynamodb_resource = None
_aioboto_session = None

def update(
self,
Expand Down Expand Up @@ -223,69 +226,103 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
batch_result: List[
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
] = []

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = {
table_instance.name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}
batch_entity_ids = self._to_resource_batch_get_payload(
online_config, table_instance.name, batch
)
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
)
response = response.get("Responses")
table_responses = response.get(table_instance.name)
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append(
(datetime.fromisoformat(tbl_res["event_ts"]), res)
)
entity_idx += 1

# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
batch_result = self._process_batch_get_response(
table_instance.name, response, entity_ids, batch
)
result.extend(batch_result)
return result

async def online_read_async(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
"""
Reads features values for the given entity keys asynchronously.
Args:
config: The config for the current feature store.
table: The feature view whose feature values should be read.
entity_keys: The list of entity keys for which feature values should be read.
requested_features: The list of features that should be read.
Returns:
A list of the same length as entity_keys. Each item in the list is a tuple where the first
item is the event timestamp for the row, and the second item is a dict mapping feature names
to values, which are returned in proto format.
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
table_name = _get_table_name(online_config, config, table)

deserialize = TypeDeserializer().deserialize

def to_tbl_resp(raw_client_response):
return {
"entity_id": deserialize(raw_client_response["entity_id"]),
"event_ts": deserialize(raw_client_response["event_ts"]),
"values": deserialize(raw_client_response["values"]),
}

async with self._get_aiodynamodb_client(online_config.region) as client:
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = self._to_client_batch_get_payload(
online_config, table_name, batch
)
response = await client.batch_get_item(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp
)
result.extend(batch_result)
return result

def _get_aioboto_session(self):
if self._aioboto_session is None:
self._aioboto_session = session.get_session()
return self._aioboto_session

def _get_aiodynamodb_client(self, region: str):
return self._get_aioboto_session().create_client("dynamodb", region_name=region)

def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
if self._dynamodb_client is None:
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url)
Expand All @@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list) -> Any:
def _sort_dynamodb_response(
self,
responses: list,
order: list,
to_tbl_response: Callable = lambda raw_dict: raw_dict,
) -> Any:
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered: Any = [
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
(order_with_index[tbl_res["entity_id"]], tbl_res)
for tbl_res in map(to_tbl_response, responses)
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
Expand Down Expand Up @@ -341,6 +384,64 @@ def _write_batch_non_duplicates(
if progress:
progress(1)

def _process_batch_get_response(
self, table_name, response, entity_ids, batch, **sort_kwargs
):
response = response.get("Responses")
table_responses = response.get(table_name)

batch_result = []
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids, **sort_kwargs
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
entity_idx += 1
# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
return batch_result

@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
return [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]

@staticmethod
def _to_resource_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}

@staticmethod
def _to_client_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": {"S": entity_id}} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client(
Expand Down
Loading

0 comments on commit b5ef384

Please sign in to comment.