Skip to content

Commit

Permalink
concurrent fetching for multiple entities
Browse files Browse the repository at this point in the history
minimal handling of exceptions in concurrent query execution
read_concurrency parameter in Cassandra online store config yaml

Signed-off-by: Stefano Lottini <stefano.lottini@datastax.com>
  • Loading branch information
hemidactylus committed Nov 21, 2022
1 parent 5c9b6fe commit e9c04f9
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 30 deletions.
3 changes: 2 additions & 1 deletion docs/reference/online-stores/cassandra.md
Expand Up @@ -32,6 +32,7 @@ online_store:
load_balancing: # optional
local_dc: 'datacenter1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```
{% endcode %}

Expand All @@ -52,7 +53,7 @@ online_store:
load_balancing: # optional
local_dc: 'eu-central-1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional

read_concurrency: 100 # optional
```
{% endcode %}

Expand Down
Expand Up @@ -58,6 +58,7 @@ online_store:
load_balancing: # optional
local_dc: 'datacenter1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```

#### Astra DB setup:
Expand All @@ -84,6 +85,7 @@ online_store:
load_balancing: # optional
local_dc: 'eu-central-1' # optional
load_balancing_policy: 'TokenAwarePolicy(DCAwareRoundRobinPolicy)' # optional
read_concurrency: 100 # optional
```

#### Protocol version and load-balancing settings
Expand Down Expand Up @@ -111,6 +113,14 @@ The former parameter is a region name for Astra DB instances (as can be verified
See the source code of the online store integration for the allowed values of
the latter parameter.

#### Read concurrency value

You can optionally specify the value of `read_concurrency`, which will be
passed to the Cassandra driver function handling
[concurrent reading of multiple entities](https://docs.datastax.com/en/developer/python-driver/3.25/api/cassandra/concurrent/#module-cassandra.concurrent).
Consult the reference for guidance on this parameter (which in most cases can be left at its default value of 100).
This is relevant only for retrieval of several entities at once.

### More info

For a more detailed walkthrough, please see the
Expand Down
Expand Up @@ -30,6 +30,7 @@
ResultSet,
Session,
)
from cassandra.concurrent import execute_concurrent_with_args
from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
from cassandra.query import PreparedStatement
from pydantic import StrictFloat, StrictInt, StrictStr
Expand Down Expand Up @@ -166,6 +167,14 @@ class CassandraLoadBalancingPolicy(FeastConfigBaseModel):
wrapped into an execution profile if present.
"""

read_concurrency: Optional[StrictInt] = 100
"""
Value of the `concurrency` parameter internally passed to Cassandra driver's
`execute_concurrent_with_args` call.
See https://docs.datastax.com/en/developer/python-driver/3.25/api/cassandra/concurrent/#module-cassandra.concurrent .
Default: 100.
"""


class CassandraOnlineStore(OnlineStore):
"""
Expand Down Expand Up @@ -358,32 +367,36 @@ def online_read(

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

for entity_key in entity_keys:
entity_key_bin = serialize_entity_key(
entity_key_bins = [
serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for entity_key in entity_keys
]

with tracing_span(name="remote_call"):
feature_rows_sequence = self._read_rows_by_entity_keys(
config,
project,
table,
entity_key_bins,
columns=["feature_name", "value", "event_ts"],
)

with tracing_span(name="remote_call"):
feature_rows = self._read_rows_by_entity_key(
config,
project,
table,
entity_key_bin,
columns=["feature_name", "value", "event_ts"],
)

for entity_key_bin, feature_rows in zip(entity_key_bins, feature_rows_sequence):
res = {}
res_ts = None
for feature_row in feature_rows:
if (
requested_features is None
or feature_row.feature_name in requested_features
):
val = ValueProto()
val.ParseFromString(feature_row.value)
res[feature_row.feature_name] = val
res_ts = feature_row.event_ts
if feature_rows:
for feature_row in feature_rows:
if (
requested_features is None
or feature_row.feature_name in requested_features
):
val = ValueProto()
val.ParseFromString(feature_row.value)
res[feature_row.feature_name] = val
res_ts = feature_row.event_ts
if not res:
result.append((None, None))
else:
Expand Down Expand Up @@ -479,12 +492,12 @@ def _write_rows(
params,
)

def _read_rows_by_entity_key(
def _read_rows_by_entity_keys(
self,
config: RepoConfig,
project: str,
table: FeatureView,
entity_key_bin: str,
entity_key_bins: List[str],
columns: Optional[List[str]] = None,
) -> ResultSet:
"""
Expand All @@ -500,7 +513,25 @@ def _read_rows_by_entity_key(
fqtable=fqtable,
columns=projection_columns,
)
return session.execute(select_cql, [entity_key_bin])
retrieval_results = execute_concurrent_with_args(
session,
select_cql,
((entity_key_bin,) for entity_key_bin in entity_key_bins),
concurrency=config.online_store.read_concurrency,
)
# execute_concurrent_with_args returns a sequence
# of (success, result_or_exception) pairs:
returned_sequence = []
for success, result_or_exception in retrieval_results:
if success:
returned_sequence.append(result_or_exception)
else:
# an exception
logger.error(
f"Cassandra online store exception during concurrent fetching: {str(result_or_exception)}"
)
returned_sequence.append(None)
return returned_sequence

def _drop_table(
self,
Expand Down
27 changes: 21 additions & 6 deletions sdk/python/feast/templates/cassandra/bootstrap.py
Expand Up @@ -70,16 +70,16 @@ def collect_cassandra_store_settings():
sys.exit(1)
needs_port = click.confirm("Need to specify port?", default=False)
if needs_port:
c_port = click.prompt("Port to use", default=9042, type=int)
c_port = click.prompt(" Port to use", default=9042, type=int)
else:
c_port = None
use_auth = click.confirm(
"Do you need username/password?",
default=False,
)
if use_auth:
c_username = click.prompt("Database username")
c_password = click.prompt("Database password", hide_input=True)
c_username = click.prompt(" Database username")
c_password = click.prompt(" Database password", hide_input=True)
else:
c_username = None
c_password = None
Expand All @@ -95,7 +95,7 @@ def collect_cassandra_store_settings():
)
if specify_protocol_version:
c_protocol_version = click.prompt(
"Protocol version",
" Protocol version",
default={"A": 4, "C": 5}.get(db_type, 5),
type=int,
)
Expand All @@ -105,11 +105,11 @@ def collect_cassandra_store_settings():
specify_lb = click.confirm("Specify load-balancing?", default=False)
if specify_lb:
c_local_dc = click.prompt(
"Local datacenter (for load-balancing)",
" Local datacenter (for load-balancing)",
default="datacenter1" if db_type == "C" else None,
)
c_load_balancing_policy = click.prompt(
"Load-balancing policy",
" Load-balancing policy",
type=click.Choice(
[
"TokenAwarePolicy(DCAwareRoundRobinPolicy)",
Expand All @@ -122,6 +122,12 @@ def collect_cassandra_store_settings():
c_local_dc = None
c_load_balancing_policy = None

needs_concurrency = click.confirm("Specify read concurrency level?", default=False)
if needs_concurrency:
c_concurrency = click.prompt(" Concurrency level?", default=100, type=int)
else:
c_concurrency = None

return {
"c_secure_bundle_path": c_secure_bundle_path,
"c_hosts": c_hosts,
Expand All @@ -132,6 +138,7 @@ def collect_cassandra_store_settings():
"c_protocol_version": c_protocol_version,
"c_local_dc": c_local_dc,
"c_load_balancing_policy": c_load_balancing_policy,
"c_concurrency": c_concurrency,
}


Expand All @@ -149,6 +156,7 @@ def apply_cassandra_store_settings(config_file, settings):
'c_protocol_version'
'c_local_dc'
'c_load_balancing_policy'
'c_concurrency'
"""
write_setting_or_remove(
config_file,
Expand Down Expand Up @@ -216,6 +224,13 @@ def apply_cassandra_store_settings(config_file, settings):
remove_lines_from_file(config_file, "load_balancing:")
remove_lines_from_file(config_file, "local_dc:")
remove_lines_from_file(config_file, "load_balancing_policy:")
#
write_setting_or_remove(
config_file,
settings["c_concurrency"],
"read_concurrency",
"100",
)


def bootstrap():
Expand Down
Expand Up @@ -16,4 +16,5 @@ online_store:
load_balancing:
local_dc: c_local_dc
load_balancing_policy: c_load_balancing_policy
read_concurrency: 100
entity_key_serialization_version: 2

0 comments on commit e9c04f9

Please sign in to comment.