Merge pull request #1114 from haaawk/stream_ids_fix
Stop reusing stream ids of requests that have timed out due to client-side timeout (#1114)

* ResponseFuture: do not return the stream ID on client timeout

When a timeout occurs, the ResponseFuture associated with the query
returns its stream ID to the associated connection's free stream ID pool
- so that the stream ID can be immediately reused by another query.

However, that is incorrect and dangerous. If query A times out before it
receives a response from the cluster, a different query B might be
issued on the same connection and stream. If the response for query A
arrives earlier than the response for query B, it might be
misinterpreted as the response for query B.

This commit changes the logic so that stream IDs are not returned on
timeout - now, they are only returned after receiving a response.
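
To illustrate the race, here is a minimal, self-contained sketch
(hypothetical names, not the driver's API) of the old behavior, where
the timeout handler frees the stream ID immediately:

pending = {}    # stream_id -> query, as the driver tracks response callbacks
free_ids = [5]  # the connection's free stream ID pool

def send(query):
    stream_id = free_ids.pop()
    pending[stream_id] = query
    return stream_id

def on_timeout(stream_id):
    # The old, unsafe behavior: free the stream ID right away.
    del pending[stream_id]
    free_ids.append(stream_id)

def on_response(stream_id, payload):
    print("%s received %r" % (pending[stream_id], payload))

sid_a = send("query A")            # A is sent on stream 5
on_timeout(sid_a)                  # A times out client-side; stream 5 freed
sid_b = send("query B")            # B reuses stream 5
on_response(sid_b, "result of A")  # A's late response is delivered to B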

* Connection: fix tracking of in_flight requests

This commit fixes tracking of in_flight requests. Previously, in case of
a client-side timeout, the stream ID was not returned to the pool, but
the in_flight counter was decremented anyway. This counter is used to
determine whether there is a need to wait for stream IDs to be freed -
without this patch, the driver could think, based on a low in_flight
counter, that it could initiate another request, while there were no
free stream IDs to allocate, so an assertion was triggered and the
connection was defuncted and opened again.

Now, requests timed out on the client side are tracked in the
orphaned_request_ids field, and the in_flight counter is decremented
only after the response is received.
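
The fixed accounting can be summarized by the following sketch (a
simplified model with assumed names, not the actual Connection class):
the timeout path only marks the stream as orphaned, and the response
path is where in_flight is decremented and the ID is freed.

import threading

class ConnectionSketch:
    """Simplified model of the fixed accounting (assumed names)."""

    def __init__(self, max_request_id=127):
        self.lock = threading.Lock()
        self.in_flight = 0                  # includes orphaned requests
        self.request_ids = list(range(max_request_id + 1))
        self.orphaned_request_ids = set()

    def start_request(self):
        with self.lock:
            self.in_flight += 1
            return self.request_ids.pop()

    def on_client_timeout(self, stream_id):
        with self.lock:
            # Do not free the stream ID or decrement in_flight here:
            # the node may still respond on this stream.
            self.orphaned_request_ids.add(stream_id)

    def on_response(self, stream_id):
        with self.lock:
            self.orphaned_request_ids.discard(stream_id)
            self.in_flight -= 1             # the slot is free only now
            self.request_ids.append(stream_id)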

* Connection: notify owning pool about released orphaned streams

Before this patch, the following situation could occur:

1. On a single connection, multiple requests are spawned up to the
   maximum concurrency,
2. We want to issue more requests but we need to wait on a condition
   variable because requests spawned in 1. took all stream IDs and we
   need to wait until some of them are freed,
3. All requests from point 1. time out on the client side - we cannot
   free their stream IDs until the database node responds,
4. Responses for requests issued in point 1. arrive, but the Connection
   class has no access to the condition variable mentioned in point 2.,
   so no requests from point 2. are admitted,
5. Requests from point 2. waiting on the condition variable time out
   even though stream IDs are available.

This commit adds an _on_orphaned_stream_released field to the Connection
class; the connection now notifies the owning pool, by invoking the
_on_orphaned_stream_released callback, whenever a timed-out request
receives a late response and its stream ID is freed.
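
A rough sketch of the wiring (hypothetical names; in the actual diff the
callback is threaded through connection_factory):

import threading

class PoolSketch:
    """Simplified model of the pool side of the notification."""

    def __init__(self, connection_factory):
        self._lock = threading.Lock()
        self._stream_available_condition = threading.Condition(self._lock)
        # The pool registers its callback when creating the connection.
        self._connection = connection_factory(
            on_orphaned_stream_released=self.on_orphaned_stream_released)

    def on_orphaned_stream_released(self):
        # A late response freed an orphaned stream ID: wake a borrower
        # blocked on the condition variable in borrow_connection().
        with self._stream_available_condition:
            self._stream_available_condition.notify()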

* HostConnection: implement replacing overloaded connections

In a situation of very high overload or poor networking conditions, it
might happen that there is a large number of outstanding requests on a
single connection. Each request reserves a stream ID which cannot be
reused until a response for it arrives, even if the request already
timed out on the client side. Because the pool of available stream IDs
for a single connection is limited, such a situation might cause the set
of free stream IDs to shrink to a very small size (including zero),
which drastically reduces the available concurrency on the connection,
or even renders it unusable for some time.

In order to prevent this, the following strategy is adopted: when the
number of orphaned stream IDs reaches a certain threshold (e.g. 75% of
all available stream IDs), the connection becomes marked as overloaded.
Meanwhile, a new connection is opened - when it becomes available, it
replaces the old one, and the old connection is moved to "trash" where
it waits until all its outstanding requests either respond or time out.
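
The detection side reduces to a simple threshold check; a sketch under
the assumption of the 3/4 ratio used in the diff below (should_replace
is a hypothetical helper, not driver API):

def orphaned_threshold(max_in_flight=128):
    # Mirrors the ratio in connection.py: 3/4 of all stream IDs.
    return 3 * max_in_flight // 4

def should_replace(connection, max_in_flight=128):
    """Sketch: mark the connection for replacement once orphaned
    stream IDs reach the threshold."""
    return len(connection.orphaned_request_ids) >= orphaned_threshold(max_in_flight)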

This feature is implemented for HostConnection but not for
HostConnectionPool, which means that it will only work for clusters
which use protocol v3 or newer.

This fix is heavily inspired by the fix for JAVA-1519.

Co-authored-by: Piotr Dulikowski <piodul@scylladb.com>
haaawk and piodul committed Nov 23, 2021
1 parent 1759428 commit 387150a
Showing 6 changed files with 158 additions and 30 deletions.
11 changes: 9 additions & 2 deletions cassandra/cluster.py
@@ -4361,10 +4361,17 @@ def _on_timeout(self, _attempts=0):

pool = self.session._pools.get(self._current_host)
if pool and not pool.is_shutdown:
# Do not return the stream ID to the pool yet. We cannot reuse it
# because the node might still be processing the query and will
# return a late response to that query - if we used such stream
# before the response to the previous query has arrived, the new
# query could get a response from the old query
with self._connection.lock:
self._connection.request_ids.append(self._req_id)
self._connection.orphaned_request_ids.add(self._req_id)
if len(self._connection.orphaned_request_ids) >= self._connection.orphaned_threshold:
self._connection.orphaned_threshold_reached = True

pool.return_connection(self._connection)
pool.return_connection(self._connection, stream_was_orphaned=True)

errors = self._errors
if not errors:
32 changes: 31 additions & 1 deletion cassandra/connection.py
@@ -690,6 +690,7 @@ class Connection(object):

# The current number of operations that are in flight. More precisely,
# the number of request IDs that are currently in use.
# This includes orphaned requests.
in_flight = 0

# Max concurrent requests allowed per connection. This is set optimistically high, allowing
@@ -707,6 +708,20 @@ class Connection(object):
# request_ids set
highest_request_id = 0

# Tracks the request IDs which are no longer waited on (timed out), but
# cannot be reused yet because the node might still send a response
# on this stream
orphaned_request_ids = None

# Set to true if the orphaned stream ID count crosses the configured
# threshold and the connection will be replaced
orphaned_threshold_reached = False

# If the number of orphaned streams reaches this threshold, this connection
# will become marked and will be replaced with a new connection by the
# owning pool (currently, only HostConnection supports this)
orphaned_threshold = 3 * max_in_flight // 4

is_defunct = False
is_closed = False
lock = None
@@ -733,6 +748,8 @@ class Connection(object):

_is_checksumming_enabled = False

_on_orphaned_stream_released = None

@property
def _iobuf(self):
# backward compatibility, to avoid any change in the reactors
@@ -742,7 +759,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression=True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None):
ssl_context=None, on_orphaned_stream_released=None):

# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)
@@ -764,6 +781,8 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
self._io_buffer = _ConnectionIOBuffer(self)
self._continuous_paging_sessions = {}
self._socket_writable = True
self.orphaned_request_ids = set()
self._on_orphaned_stream_released = on_orphaned_stream_released

if ssl_options:
self._check_hostname = bool(self.ssl_options.pop('check_hostname', False))
@@ -1188,11 +1207,22 @@ def process_msg(self, header, body):
decoder = paging_session.decoder
result_metadata = None
else:
need_notify_of_release = False
with self.lock:
if stream_id in self.orphaned_request_ids:
self.in_flight -= 1
self.orphaned_request_ids.remove(stream_id)
need_notify_of_release = True
if need_notify_of_release and self._on_orphaned_stream_released:
self._on_orphaned_stream_released()

try:
callback, decoder, result_metadata = self._requests.pop(stream_id)
# This can only happen if the stream_id was
# removed due to an OperationTimedOut
except KeyError:
with self.lock:
self.request_ids.append(stream_id)
return

try:
97 changes: 80 additions & 17 deletions cassandra/pool.py
@@ -390,6 +390,10 @@ def __init__(self, host, host_distance, session):
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
self._stream_available_condition = Condition(self._lock)
self._is_replacing = False
# Contains connections which shouldn't be used anymore
# and are waiting until all requests time out or complete
# so that we can dispose of them.
self._trash = set()

if host_distance == HostDistance.IGNORED:
log.debug("Not opening connection to ignored host %s", self.host)
@@ -399,42 +403,59 @@ def __init__(self, host, host_distance, session):
return

log.debug("Initializing connection for host %s", self.host)
self._connection = session.cluster.connection_factory(host.endpoint)
self._connection = session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
self._keyspace = session.keyspace
if self._keyspace:
self._connection.set_keyspace_blocking(self._keyspace)
log.debug("Finished initializing connection for host %s", self.host)

def borrow_connection(self, timeout):
def _get_connection(self):
if self.is_shutdown:
raise ConnectionException(
"Pool for %s is shutdown" % (self.host,), self.host)

conn = self._connection
if not conn:
raise NoConnectionsAvailable()
return conn

def borrow_connection(self, timeout):
conn = self._get_connection()
if conn.orphaned_threshold_reached:
with self._lock:
if not self._is_replacing:
self._is_replacing = True
self._session.submit(self._replace, conn)
log.debug(
"Connection to host %s reached orphaned stream limit, replacing...",
self.host
)

start = time.time()
remaining = timeout
while True:
with conn.lock:
if conn.in_flight < conn.max_request_id:
if not (conn.orphaned_threshold_reached and conn.is_closed) and conn.in_flight < conn.max_request_id:
conn.in_flight += 1
return conn, conn.get_request_id()
if timeout is not None:
remaining = timeout - time.time() + start
if remaining < 0:
break
with self._stream_available_condition:
self._stream_available_condition.wait(remaining)
if conn.orphaned_threshold_reached and conn.is_closed:
conn = self._get_connection()
else:
self._stream_available_condition.wait(remaining)

raise NoConnectionsAvailable("All request IDs are currently in use")

def return_connection(self, connection):
with connection.lock:
connection.in_flight -= 1
with self._stream_available_condition:
self._stream_available_condition.notify()
def return_connection(self, connection, stream_was_orphaned=False):
if not stream_was_orphaned:
with connection.lock:
connection.in_flight -= 1
with self._stream_available_condition:
self._stream_available_condition.notify()

if connection.is_defunct or connection.is_closed:
if connection.signaled_error and not self.shutdown_on_error:
@@ -461,6 +482,24 @@ def return_connection(self):
return
self._is_replacing = True
self._session.submit(self._replace, connection)
else:
if connection in self._trash:
with connection.lock:
if connection.in_flight == len(connection.orphaned_request_ids):
with self._lock:
if connection in self._trash:
self._trash.remove(connection)
log.debug("Closing trashed connection (%s) to %s", id(connection), self.host)
connection.close()
return

def on_orphaned_stream_released(self):
"""
Called when a response for an orphaned stream (timed out on the client
side) was received.
"""
with self._stream_available_condition:
self._stream_available_condition.notify()

def _replace(self, connection):
with self._lock:
@@ -469,17 +508,23 @@ def _replace(self, connection):

log.debug("Replacing connection (%s) to %s", id(connection), self.host)
try:
conn = self._session.cluster.connection_factory(self.host.endpoint)
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._keyspace)
self._connection = conn
except Exception:
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
self._session.submit(self._replace, connection)
else:
with self._lock:
self._is_replacing = False
self._stream_available_condition.notify()
with connection.lock:
with self._lock:
if connection.orphaned_threshold_reached:
if connection.in_flight == len(connection.orphaned_request_ids):
connection.close()
else:
self._trash.add(connection)
self._is_replacing = False
self._stream_available_condition.notify()

def shutdown(self):
with self._lock:
@@ -493,6 +538,16 @@ def shutdown(self):
self._connection.close()
self._connection = None

trash_conns = None
with self._lock:
if self._trash:
trash_conns = self._trash
self._trash = set()

if trash_conns is not None:
for conn in trash_conns:
conn.close()

def _set_keyspace_for_all_conns(self, keyspace, callback):
if self.is_shutdown or not self._connection:
return
@@ -548,7 +603,7 @@ def __init__(self, host, host_distance, session):

log.debug("Initializing new connection pool for host %s", self.host)
core_conns = session.cluster.get_core_connections_per_host(host_distance)
self._connections = [session.cluster.connection_factory(host.endpoint)
self._connections = [session.cluster.connection_factory(host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
for i in range(core_conns)]

self._keyspace = session.keyspace
@@ -652,7 +707,7 @@ def _add_conn_if_under_max(self):

log.debug("Going to open new connection to host %s", self.host)
try:
conn = self._session.cluster.connection_factory(self.host.endpoint)
conn = self._session.cluster.connection_factory(self.host.endpoint, on_orphaned_stream_released=self.on_orphaned_stream_released)
if self._keyspace:
conn.set_keyspace_blocking(self._session.keyspace)
self._next_trash_allowed_at = time.time() + _MIN_TRASH_INTERVAL
@@ -712,9 +767,10 @@ def _wait_for_conn(self, timeout):

raise NoConnectionsAvailable()

def return_connection(self, connection):
def return_connection(self, connection, stream_was_orphaned=False):
with connection.lock:
connection.in_flight -= 1
if not stream_was_orphaned:
connection.in_flight -= 1
in_flight = connection.in_flight

if connection.is_defunct or connection.is_closed:
@@ -750,6 +806,13 @@ def return_connection(self):
else:
self._signal_available_conn()

def on_orphaned_stream_released(self):
"""
Called when a response for an orphaned stream (timed out on the client
side) was received.
"""
self._signal_available_conn()

def _maybe_trash_connection(self, connection):
core_conns = self._session.cluster.get_core_connections_per_host(self.host_distance)
did_trash = False
Binary file added tests/unit/.noseids
Binary file not shown.
20 changes: 10 additions & 10 deletions tests/unit/test_host_connection_pool.py
@@ -26,7 +26,6 @@
from cassandra.pool import Host, NoConnectionsAvailable
from cassandra.policies import HostDistance, SimpleConvictionPolicy


class _PoolTests(unittest.TestCase):
PoolImpl = None
uses_single_connection = None
@@ -45,7 +44,7 @@ def test_borrow_and_return(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

c, request_id = pool.borrow_connection(timeout=0.01)
self.assertIs(c, conn)
@@ -64,7 +63,7 @@ def test_failed_wait_for_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
@@ -82,7 +81,7 @@ def test_successful_wait_for_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
Expand Down Expand Up @@ -110,7 +109,7 @@ def test_spawn_when_at_max(self):
session.cluster.get_max_connections_per_host.return_value = 2

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
self.assertEqual(1, conn.in_flight)
@@ -133,7 +132,7 @@ def test_return_defunct_connection(self):
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
Expand All @@ -148,11 +147,12 @@ def test_return_defunct_connection_on_down_host(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=False,
max_request_id=100, signaled_error=False)
max_request_id=100, signaled_error=False,
orphaned_threshold_reached=False)
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_defunct = True
@@ -169,11 +169,11 @@ def test_return_closed_connection(self):
host = Mock(spec=Host, address='ip1')
session = self.make_session()
conn = NonCallableMagicMock(spec=Connection, in_flight=0, is_defunct=False, is_closed=True, max_request_id=100,
signaled_error=False)
signaled_error=False, orphaned_threshold_reached=False)
session.cluster.connection_factory.return_value = conn

pool = self.PoolImpl(host, HostDistance.LOCAL, session)
session.cluster.connection_factory.assert_called_once_with(host.endpoint)
session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released)

pool.borrow_connection(timeout=0.01)
conn.is_closed = True