Skip to content

Add support for CASSANDRA-7660 #256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,13 +1534,13 @@ def prepare(self, query):
future = ResponseFuture(self, message, query=None)
try:
future.send_request()
query_id, column_metadata = future.result(self.default_timeout)
query_id, column_metadata, pk_indexes = future.result(self.default_timeout)
except Exception:
log.exception("Error preparing query:")
raise

prepared_statement = PreparedStatement.from_message(
query_id, column_metadata, self.cluster.metadata, query, self.keyspace,
query_id, column_metadata, pk_indexes, self.cluster.metadata, query, self.keyspace,
self._protocol_version)

host = future._current_host
Expand Down
35 changes: 31 additions & 4 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def recv_body(cls, f, protocol_version, user_type_map):
ksname = read_string(f)
results = ksname
elif kind == RESULT_KIND_PREPARED:
results = cls.recv_results_prepared(f, user_type_map)
results = cls.recv_results_prepared(f, protocol_version, user_type_map)
elif kind == RESULT_KIND_SCHEMA_CHANGE:
results = cls.recv_results_schema_change(f, protocol_version)
return cls(kind, results, paging_state)
Expand All @@ -578,16 +578,17 @@ def recv_results_rows(cls, f, protocol_version, user_type_map):
return (paging_state, (colnames, parsed_rows))

@classmethod
def recv_results_prepared(cls, f, user_type_map):
def recv_results_prepared(cls, f, protocol_version, user_type_map):
query_id = read_binary_string(f)
_, column_metadata = cls.recv_results_metadata(f, user_type_map)
return (query_id, column_metadata)
column_metadata, pk_indexes = cls.recv_prepared_metadata(f, protocol_version, user_type_map)
return (query_id, column_metadata, pk_indexes)

@classmethod
def recv_results_metadata(cls, f, user_type_map):
flags = read_int(f)
glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
colcount = read_int(f)

if flags & cls._HAS_MORE_PAGES_FLAG:
paging_state = read_binary_longstring(f)
else:
Expand All @@ -608,6 +609,32 @@ def recv_results_metadata(cls, f, user_type_map):
column_metadata.append((colksname, colcfname, colname, coltype))
return paging_state, column_metadata

@classmethod
def recv_prepared_metadata(cls, f, protocol_version, user_type_map):
flags = read_int(f)
glob_tblspec = bool(flags & cls._FLAGS_GLOBAL_TABLES_SPEC)
colcount = read_int(f)
pk_indexes = None
if protocol_version >= 4:
num_pk_indexes = read_int(f)
pk_indexes = [read_short(f) for _ in range(num_pk_indexes)]

if glob_tblspec:
ksname = read_string(f)
cfname = read_string(f)
column_metadata = []
for _ in range(colcount):
if glob_tblspec:
colksname = ksname
colcfname = cfname
else:
colksname = read_string(f)
colcfname = read_string(f)
colname = read_string(f)
coltype = cls.read_type(f, user_type_map)
column_metadata.append((colksname, colcfname, colname, coltype))
return column_metadata, pk_indexes

@classmethod
def recv_results_schema_change(cls, f, protocol_version):
return EventMessage.recv_schema_change(f, protocol_version)
Expand Down
43 changes: 23 additions & 20 deletions cassandra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,29 +353,32 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query, keyspa
self.fetch_size = fetch_size

@classmethod
def from_message(cls, query_id, column_metadata, cluster_metadata, query, prepared_keyspace, protocol_version):
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata, query, prepared_keyspace, protocol_version):
if not column_metadata:
return PreparedStatement(column_metadata, query_id, None, query, prepared_keyspace, protocol_version)

partition_key_columns = None
routing_key_indexes = None

ks_name, table_name, _, _ = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(ks_name)
if ks_meta:
table_meta = ks_meta.tables.get(table_name)
if table_meta:
partition_key_columns = table_meta.partition_key

# make a map of {column_name: index} for each column in the statement
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata))

# a list of which indexes in the statement correspond to partition key items
try:
routing_key_indexes = [statement_indexes[c.name]
for c in partition_key_columns]
except KeyError: # we're missing a partition key component in the prepared
pass # statement; just leave routing_key_indexes as None
if pk_indexes:
routing_key_indexes = pk_indexes
else:
partition_key_columns = None
routing_key_indexes = None

ks_name, table_name, _, _ = column_metadata[0]
ks_meta = cluster_metadata.keyspaces.get(ks_name)
if ks_meta:
table_meta = ks_meta.tables.get(table_name)
if table_meta:
partition_key_columns = table_meta.partition_key

# make a map of {column_name: index} for each column in the statement
statement_indexes = dict((c[2], i) for i, c in enumerate(column_metadata))

# a list of which indexes in the statement correspond to partition key items
try:
routing_key_indexes = [statement_indexes[c.name]
for c in partition_key_columns]
except KeyError: # we're missing a partition key component in the prepared
pass # statement; just leave routing_key_indexes as None

return PreparedStatement(column_metadata, query_id, routing_key_indexes,
query, prepared_keyspace, protocol_version)
Expand Down