Skip to content

Commit

Permalink
Make gevent.pywsgi stop dealing with chunks when the connection is be…
Browse files Browse the repository at this point in the history
…ing upgraded.

Let the application have full control over input and output.

Fixes #1712.
  • Loading branch information
jamadden committed Dec 19, 2020
1 parent f54fa61 commit 8c49775
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 13 deletions.
7 changes: 7 additions & 0 deletions docs/changes/1712.bugfix
@@ -0,0 +1,7 @@
Make `gevent.pywsgi` trying to enforce the rules for reading chunked input or
``Content-Length`` terminated input when the connection is being
upgraded, for example to a websocket connection. Likewise, if the
protocol was switched by returning a ``101`` status, stop trying to
automatically chunk the responses.

Reported by Kavindu Santhusa.
40 changes: 28 additions & 12 deletions src/gevent/pywsgi.py
Expand Up @@ -415,6 +415,9 @@ def MessageClass(self, *args):
time_finish = 0 # time.time() when done handling request
headers_sent = False # Have we already sent headers?
response_use_chunked = False # Write with transfer-encoding chunked
# Was the connection upgraded? We shouldn't try to chunk writes in that
# case.
connection_upgraded = False
environ = None # Dict from self.get_environ
application = None # application callable from self.server.application
requestline = None # native str 'GET / HTTP/1.1'
Expand Down Expand Up @@ -486,6 +489,7 @@ def handle(self):
pass
self.__dict__.pop('socket', None)
self.__dict__.pop('rfile', None)
self.__dict__.pop('wsgi_input', None)

def _check_http_version(self):
version_str = self.request_version
Expand Down Expand Up @@ -697,10 +701,19 @@ def handle_one_request(self):

return True # read more requests

def _connection_upgrade_requested(self):
if self.headers.get('Connection', '').lower() == 'upgrade':
return True
if self.headers.get('Upgrade', '').lower() == 'websocket':
return True
return False

def finalize_headers(self):
if self.provided_date is None:
self.response_headers.append((b'Date', format_date_time(time.time())))

self.connection_upgraded = self.code == 101

if self.code not in (304, 204):
# the reply will include message-body; make sure we have either Content-Length or chunked
if self.provided_content_length is None:
Expand All @@ -711,8 +724,11 @@ def finalize_headers(self):
total_len_str = total_len_str.encode("latin-1")
self.response_headers.append((b'Content-Length', total_len_str))
else:
if self.request_version != 'HTTP/1.0':
self.response_use_chunked = True
self.response_use_chunked = (
not self.connection_upgraded
and self.request_version != 'HTTP/1.0'
)
if self.response_use_chunked:
self.response_headers.append((b'Transfer-Encoding', b'chunked'))

def _sendall(self, data):
Expand Down Expand Up @@ -975,6 +991,7 @@ def handle_one_response(self):

self.result = None
self.response_use_chunked = False
self.connection_upgraded = False
self.response_length = 0

try:
Expand Down Expand Up @@ -1103,10 +1120,7 @@ def get_environ(self):
# See https://github.com/gevent/gevent/issues/1667 for discussion.
env['SCRIPT_NAME'] = ''

if '?' in self.path:
path, query = self.path.split('?', 1)
else:
path, query = self.path, ''
path, query = self.path.split('?', 1) if '?' in self.path else (self.path, '')
# Note that self.path contains the original str object; if it contains
# encoded escapes, it will NOT match PATH_INFO.
env['PATH_INFO'] = unquote_latin1(path)
Expand Down Expand Up @@ -1134,18 +1148,20 @@ def get_environ(self):
else:
env[key] = value

if env.get('HTTP_EXPECT') == '100-continue':
sock = self.socket
else:
sock = None
sock = self.socket if env.get('HTTP_EXPECT') == '100-continue' else None

chunked = env.get('HTTP_TRANSFER_ENCODING', '').lower() == 'chunked'
# Input refuses to read if the data isn't chunked, and there is no content_length
# provided. For 'Upgrade: Websocket' requests, neither of those things is true.
handling_reads = not self._connection_upgrade_requested()

self.wsgi_input = Input(self.rfile, self.content_length, socket=sock, chunked_input=chunked)
env['wsgi.input'] = self.wsgi_input

env['wsgi.input'] = self.wsgi_input if handling_reads else self.rfile
# This is a non-standard flag indicating that our input stream is
# self-terminated (returns EOF when consumed).
# See https://github.com/gevent/gevent/issues/1308
env['wsgi.input_terminated'] = True
env['wsgi.input_terminated'] = handling_reads
return env


Expand Down
40 changes: 39 additions & 1 deletion src/gevent/tests/test__pywsgi.py
Expand Up @@ -432,18 +432,35 @@ class TestNoChunks(CommonTestMixin, TestCase):
# when returning a list of strings a shortcut is employed by the server:
# it calculates the content-length and joins all the chunks before sending
validator = None
last_environ = None

def _check_environ(self, input_terminated=True):
if input_terminated:
self.assertTrue(self.last_environ.get('wsgi.input_terminated'))
else:
self.assertFalse(self.last_environ['wsgi.input_terminated'])

def application(self, env, start_response):
self.assertTrue(env.get('wsgi.input_terminated'))
self.last_environ = env
path = env['PATH_INFO']
if path == '/':
start_response('200 OK', [('Content-Type', 'text/plain')])
return [b'hello ', b'world']
if path == '/websocket':
write = start_response('101 Switching Protocols',
[('Content-Type', 'text/plain'),
# Con:close is to make our simple client
# happy; otherwise it wants to read data from the
# body thot's being kept open.
('Connection', 'close')])
write(b'') # Trigger finalizing the headers now.
return [b'upgrading to', b'websocket']
start_response('404 Not Found', [('Content-Type', 'text/plain')])
return [b'not ', b'found']

def test_basic(self):
response, dne_response = super(TestNoChunks, self).test_basic()
self._check_environ()
self.assertFalse(response.chunks)
response.assertHeader('Content-Length', '11')
if dne_response is not None:
Expand All @@ -455,8 +472,28 @@ def test_dne(self):
fd.write(self.format_request(path='/notexist'))
response = read_http(fd, code=404, reason='Not Found', body='not found')
self.assertFalse(response.chunks)
self._check_environ()
response.assertHeader('Content-Length', '9')

class TestConnectionUpgrades(TestNoChunks):

def test_connection_upgrade(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Connection='upgrade'))
response = read_http(fd, code=101)

self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)

def test_upgrade_websocket(self):
with self.makefile() as fd:
fd.write(self.format_request(path='/websocket', Upgrade='websocket'))
response = read_http(fd, code=101)

self._check_environ(input_terminated=False)
self.assertFalse(response.chunks)


class TestNoChunks10(TestNoChunks):
HTTP_CLIENT_VERSION = '1.0'
PIPELINE_NOT_SUPPORTED_EXS = (ConnectionClosed,)
Expand All @@ -475,6 +512,7 @@ class TestExplicitContentLength(TestNoChunks): # pylint:disable=too-many-ancesto
# server - it caculates the content-length

def application(self, env, start_response):
self.last_environ = env
self.assertTrue(env.get('wsgi.input_terminated'))
path = env['PATH_INFO']
if path == '/':
Expand Down

0 comments on commit 8c49775

Please sign in to comment.