Skip to content

Commit

Permalink
Merge PRs #313 and #316
Browse files Browse the repository at this point in the history
This change addresses a regression reported in #312 that has been
introduced in v8.4.4 via #309. Before the fix, the code attempted to
modify a selectors map while looping over it that caused an exception
to be raised. It was "RuntimeError: dictionary changed size during
iteration".

Now its contents are being copied before iteration preventing looping
over a moving target.
  • Loading branch information
webknjaz committed Aug 24, 2020
3 parents b6b777c + c94eaab + a4cfe02 commit 7644616
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 16 deletions.
12 changes: 12 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
.. scm-version-title:: v8.4.5

- :issue:`312` via :pr:`313`: Fixed a regression introduced
in the earlier refactoring in v8.4.4 via :pr:`309` that
caused the connection manager to modify the selector map
while looping over it -- by :user:`liamstask`.

- :issue:`312` via :pr:`316`: Added a regression test for
the error handling in :py:meth:`~cheroot.connections.\
ConnectionManager.get_conn` to ensure more stability
-- by :user:`cyraxjoe`.

.. scm-version-title:: v8.4.4

- :issue:`304` via :pr:`309`: Refactored :py:class:`~\
Expand Down
21 changes: 12 additions & 9 deletions cheroot/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def expire(self):
# find any connections still registered with the selector
# that have not been active recently enough.
threshold = time.time() - self.server.timeout
timed_out_connections = (
timed_out_connections = [
(sock_fd, conn)
for _, (_, sock_fd, _, conn) in self._selector.get_map().items()
for _, (_, sock_fd, _, conn)
in self._selector.get_map().items()
if conn != self.server and conn.last_used < threshold
)
]
for sock_fd, conn in timed_out_connections:
self._selector.unregister(sock_fd)
conn.close()
Expand Down Expand Up @@ -137,6 +138,7 @@ def get_conn(self): # noqa: C901 # FIXME
]
except OSError:
# Mark any connection which no longer appears valid
invalid_entries = []
for _, key in self._selector.get_map().items():
# If the server socket is invalid, we'll just ignore it and
# wait to be shutdown.
Expand All @@ -146,10 +148,11 @@ def get_conn(self): # noqa: C901 # FIXME
try:
os.fstat(key.fd)
except OSError:
# Socket is invalid, close the connection
self._selector.unregister(key.fd)
conn = key.data
conn.close()
invalid_entries.append((key.fd, key.data))

for sock_fd, conn in invalid_entries:
self._selector.unregister(sock_fd)
conn.close()

# Wait for the next tick to occur.
return None
Expand Down Expand Up @@ -273,8 +276,8 @@ def close(self):
self._selector.close()

@property
def _num_connections(self): # noqa: D401
"""The current number of connections.
def _num_connections(self):
"""Return the current number of connections.
Includes any in the readable list or registered with the selector,
minus one for the server socket, which is always registered
Expand Down
24 changes: 19 additions & 5 deletions cheroot/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def send_headers(self): # noqa: C901 # FIXME
# Override the decision to not close the connection if the connection
# manager doesn't have space for it.
if not self.close_connection:
can_keep = self.server.connections.can_add_keepalive_connection
can_keep = self.server.can_add_keepalive_connection
self.close_connection = not can_keep

if b'connection' not in hkeys:
Expand Down Expand Up @@ -1781,7 +1781,8 @@ def prepare(self): # noqa: C901 # FIXME
self.socket.settimeout(1)
self.socket.listen(self.request_queue_size)

self.connections = connections.ConnectionManager(self)
# must not be accessed once stop() has been called
self._connections = connections.ConnectionManager(self)

# Create worker threads
self.requests.start()
Expand Down Expand Up @@ -1829,6 +1830,19 @@ def _run_in_thread(self):
finally:
self.stop()

@property
def can_add_keepalive_connection(self):
"""Flag whether it is allowed to add a new keep-alive connection."""
return self.ready and self._connections.can_add_keepalive_connection

def put_conn(self, conn):
"""Put an idle connection back into the ConnectionManager."""
if self.ready:
self._connections.put(conn)
else:
# server is shutting down, just close it
conn.close()

def error_log(self, msg='', level=20, traceback=False):
"""Write error message to log.
Expand Down Expand Up @@ -2021,15 +2035,15 @@ def resolve_real_bind_addr(socket_):

def tick(self):
"""Accept a new connection and put it on the Queue."""
conn = self.connections.get_conn()
conn = self._connections.get_conn()
if conn:
try:
self.requests.put(conn)
except queue.Full:
# Just drop the conn. TODO: write 503 back?
conn.close()

self.connections.expire()
self._connections.expire()

@property
def interrupt(self):
Expand Down Expand Up @@ -2097,7 +2111,7 @@ def stop(self): # noqa: C901 # FIXME
sock.close()
self.socket = None

self.connections.close()
self._connections.close()
self.requests.stop(self.shutdown_timeout)


Expand Down
149 changes: 148 additions & 1 deletion cheroot/test/test_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import socket
import time
import logging
import traceback as traceback_
from collections import namedtuple

from six.moves import range, http_client, urllib

Expand All @@ -14,6 +17,7 @@

from cheroot.test import helper, webtest
from cheroot._compat import IS_PYPY
import cheroot.server


timeout = 1
Expand Down Expand Up @@ -107,8 +111,32 @@ def _munge(string):
}


class ErrorLogMonitor:
"""Mock class to access the server error_log calls made by the server."""

ErrorLogCall = namedtuple('ErrorLogCall', ['msg', 'level', 'traceback'])

def __init__(self):
"""Initialize the server error log monitor/interceptor.
If you need to ignore a particular error message use the property
``ignored_msgs` by appending to the list the expected error messages.
"""
self.calls = []
# to be used the the teardown validation
self.ignored_msgs = []

def __call__(self, msg='', level=logging.INFO, traceback=False):
"""Intercept the call to the server error_log method."""
if traceback:
tblines = traceback_.format_exc()
else:
tblines = ''
self.calls.append(ErrorLogMonitor.ErrorLogCall(msg, level, tblines))


@pytest.fixture
def testing_server(wsgi_server_client):
def raw_testing_server(wsgi_server_client):
"""Attach a WSGI app to the given server and preconfigure it."""
app = Controller()

Expand All @@ -121,9 +149,36 @@ def _timeout(req, resp):
wsgi_server.timeout = timeout
wsgi_server.server_client = wsgi_server_client
wsgi_server.keep_alive_conn_limit = 2

return wsgi_server


@pytest.fixture
def testing_server(raw_testing_server, monkeypatch):
"""Modify the "raw" base server to monitor the error_log messages.
If you need to ignore a particular error message use the property
``testing_server.error_log.ignored_msgs`` by appending to the list
the expected error messages.
"""
# patch the error_log calls of the server instance
monkeypatch.setattr(raw_testing_server, 'error_log', ErrorLogMonitor())

yield raw_testing_server

# Teardown verification, in case that the server logged an
# error that wasn't notified to the client or we just made a mistake.
for c in raw_testing_server.error_log.calls:
if c.level <= logging.WARNING:
continue

assert c.msg in raw_testing_server.error_log.ignored_msgs, (
'Found error in the error log: '
"message = '{c.msg}', level = '{c.level}'\n"
'{c.traceback}'.format(**locals()),
)


@pytest.fixture
def test_client(testing_server):
"""Get and return a test client out of the given server."""
Expand Down Expand Up @@ -951,6 +1006,13 @@ def test_Content_Length_out(

conn.close()

# the server logs the exception that we had verified from the
# client perspective. Tell the error_log verification that
# it can ignore that message.
test_client.server_instance.error_log.ignored_msgs.append(
"ValueError('Response body exceeds the declared Content-Length.')",
)


@pytest.mark.xfail(
reason='Sometimes this test fails due to low timeout. '
Expand Down Expand Up @@ -1000,3 +1062,88 @@ def test_No_CRLF(test_client, invalid_terminator):
expected_resp_body = b'HTTP requires CRLF terminators'
assert actual_resp_body == expected_resp_body
conn.close()


class FaultySelect:
"""Mock class to insert errors in the selector.select method."""

def __init__(self, original_select):
"""Initilize helper class to wrap the selector.select method."""
self.original_select = original_select
self.request_served = False
self.os_error_triggered = False

def __call__(self, timeout):
"""Intercept the calls to selector.select."""
if self.request_served:
self.os_error_triggered = True
raise OSError('Error while selecting the client socket.')

return self.original_select(timeout)


class FaultyGetMap:
"""Mock class to insert errors in the selector.get_map method."""

def __init__(self, original_get_map):
"""Initilize helper class to wrap the selector.get_map method."""
self.original_get_map = original_get_map
self.sabotage_conn = False
self.socket_closed = False

def __call__(self):
"""Intercept the calls to selector.get_map."""
sabotage_targets = (
conn for _, (*_, conn) in self.original_get_map().items()
if isinstance(conn, cheroot.server.HTTPConnection)
) if self.sabotage_conn else ()

for conn in sabotage_targets:
# close the socket to cause OSError
conn.close()
self.socket_closed = True

return self.original_get_map()


def test_invalid_selected_connection(test_client, monkeypatch):
"""Test the error handling segment of HTTP connection selection.
See :py:meth:`cheroot.connections.ConnectionManager.get_conn`.
"""
# patch the select method
faux_select = FaultySelect(
test_client.server_instance._connections._selector.select,
)
monkeypatch.setattr(
test_client.server_instance._connections._selector,
'select',
faux_select,
)

# patch the get_map method
faux_get_map = FaultyGetMap(
test_client.server_instance._connections._selector.get_map,
)

monkeypatch.setattr(
test_client.server_instance._connections._selector,
'get_map',
faux_get_map,
)

# request a page with connection keep-alive to make sure
# we'll have a connection to be modified.
resp_status, resp_headers, resp_body = test_client.request(
'/page1', headers=[('Connection', 'Keep-Alive')],
)

assert resp_status == '200 OK'
# trigger the internal errors
faux_get_map.sabotage_conn = faux_select.request_served = True
# give time to make sure the error gets handled
time.sleep(0.2)
assert faux_select.os_error_triggered
assert faux_get_map.socket_closed
# any error in the error handling should be catched by the
# teardown verification for the error_log
2 changes: 1 addition & 1 deletion cheroot/workers/threadpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def run(self):
keep_conn_open = conn.communicate()
finally:
if keep_conn_open:
self.server.connections.put(conn)
self.server.put_conn(conn)
else:
conn.close()
if is_stats_enabled:
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
# Ref: https://stackoverflow.com/a/30624034/595220
nitpick_ignore = [
('py:class', 'cheroot.connections.ConnectionManager'),
('py:meth', 'cheroot.connections.ConnectionManager.get_conn'),

('py:const', 'socket.SO_PEERCRED'),
('py:class', '_pyio.BufferedWriter'),
('py:class', '_pyio.BufferedReader'),
Expand Down

0 comments on commit 7644616

Please sign in to comment.