Fixed byte operations on stream decoding
mitsuhiko committed May 27, 2013
1 parent 8ab9060 commit c9f9fc2
Showing 2 changed files with 66 additions and 38 deletions.
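
At its core, the change makes the stream helpers in werkzeug.wsgi pick a mode, bytes or text, from the first item they see instead of re-checking types per item. A minimal sketch of the intended post-commit behavior (output values are illustrative):

    from werkzeug.wsgi import make_line_iter

    # The first chunk decides the mode: byte input yields byte lines,
    # text input yields text lines; no limit is needed for iterables.
    list(make_line_iter([b'foo\nba', b'r\n']))   # [b'foo\n', b'bar\n']
    list(make_line_iter([u'foo\nba', u'r\n']))   # [u'foo\n', u'bar\n']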
6 changes: 5 additions & 1 deletion werkzeug/testsuite/urls.py
@@ -64,14 +64,18 @@ def test_streamed_url_decoding(self):
         self.assert_strict_equal(next(gen), ('c', item2))
         self.assert_raises(StopIteration, lambda: next(gen))
 
+    def test_stream_decoding_string_fails(self):
+        self.assert_raises(TypeError, urls.url_decode_stream, 'testing')
+
     def test_url_encoding(self):
         self.assert_strict_equal(urls.url_encode({'foo': 'bar 45'}), 'foo=bar+45')
         d = {'foo': 1, 'bar': 23, 'blah': u'Hänsel'}
         self.assert_strict_equal(urls.url_encode(d, sort=True), 'bar=23&blah=H%C3%A4nsel&foo=1')
         self.assert_strict_equal(urls.url_encode(d, sort=True, separator=u';'), 'bar=23;blah=H%C3%A4nsel;foo=1')
 
     def test_sorted_url_encode(self):
-        self.assert_strict_equal(urls.url_encode({u"a": 42, u"b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: text_type(i[0])), '1=1&2=2&a=42&b=23')
+        self.assert_strict_equal(urls.url_encode({u"a": 42, u"b": 23, 1: 1, 2: 2},
+                                                 sort=True, key=lambda i: text_type(i[0])), '1=1&2=2&a=42&b=23')
         self.assert_strict_equal(urls.url_encode({u'A': 1, u'a': 2, u'B': 3, 'b': 4}, sort=True,
                                                  key=lambda x: x[0].lower() + x[0]), 'A=1&a=2&B=3&b=4')
 
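The new test pins down the guard added to the wsgi helpers below: passing a plain string where a stream or iterator is expected now fails loudly instead of being iterated character by character. Roughly:

    from werkzeug import urls

    # Raises TypeError after this commit: a string is not a stream.
    urls.url_decode_stream('testing')
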
98 changes: 61 additions & 37 deletions werkzeug/wsgi.py
@@ -13,15 +13,15 @@
 import sys
 import posixpath
 import mimetypes
-from itertools import chain, repeat
+from itertools import chain
 from zlib import adler32
 from time import time, mktime
 from datetime import datetime
 from functools import partial
 
 from werkzeug import urls
 from werkzeug._compat import string_join, iteritems, text_type, string_types, \
-    implements_iterator
+    implements_iterator, make_literal_wrapper, to_unicode, to_bytes
 from werkzeug._internal import _patch_wrapper
 from werkzeug.http import is_resource_modified, http_date
 
@@ -650,12 +650,24 @@ def make_limited_stream(stream, limit):
     return stream
 
 
-def make_chunk_iter_func(stream, limit, buffer_size):
+def _make_chunk_iter(stream, limit, buffer_size):
     """Helper for the line and chunk iter functions."""
-    if hasattr(stream, 'read'):
-        return partial(make_limited_stream(stream, limit).read, buffer_size)
-    iterator = iter(chain(stream, repeat('')))
-    return partial(next, iterator)
+    if isinstance(stream, (bytes, bytearray, text_type)):
+        raise TypeError('Passed a string or byte object instead of '
+                        'true iterator or stream.')
+    if not hasattr(stream, 'read'):
+        def generate():
+            for item in stream:
+                if item:
+                    yield item
+    else:
+        def generate(_read=make_limited_stream(stream, limit).read):
+            while 1:
+                item = _read(buffer_size)
+                if not item:
+                    break
+                yield item
+    return generate()
 
 
 def make_line_iter(stream, limit=None, buffer_size=10 * 1024):
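
The rewritten helper now rejects strings outright and returns a generator for both source kinds: file-like objects are wrapped via make_limited_stream and read in buffer_size pieces, while plain iterables are forwarded with empty items dropped. A rough illustration (calling the private helper directly, for exposition only):

    from io import BytesIO
    from werkzeug.wsgi import _make_chunk_iter

    # Read path: a limited stream, consumed in buffer_size slices.
    list(_make_chunk_iter(BytesIO(b'abcdef'), limit=6, buffer_size=4))
    # -> [b'abcd', b'ef']

    # Iterator path: items are forwarded, empty ones are skipped.
    list(_make_chunk_iter([b'ab', b'', b'cd'], limit=None, buffer_size=4))
    # -> [b'ab', b'cd']
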
@@ -683,35 +695,44 @@ def make_line_iter(stream, limit=None, buffer_size=10 * 1024):
         is a :class:`LimitedStream`.
     :param buffer_size: The optional buffer size.
     """
+    _iter = _make_chunk_iter(stream, limit, buffer_size)
+
+    first_item = next(_iter, '')
+    if not first_item:
+        return
+
+    s = make_literal_wrapper(first_item)
+    empty = s('')
+    cr = s('\r')
+    lf = s('\n')
+    crlf = s('\r\n')
+
+    _iter = chain((first_item,), _iter)
+
     def _iter_basic_lines():
-        _read = make_chunk_iter_func(stream, limit, buffer_size)
+        _join = empty.join
         buffer = []
         while 1:
-            new_data = _read()
+            new_data = next(_iter, '')
             if not new_data:
                 break
             new_buf = []
             for item in chain(buffer, new_data.splitlines(True)):
                 new_buf.append(item)
-                if isinstance(item, text_type):
-                    if item and item[-1:] in '\r\n':
-                        yield ''.join(new_buf)
-                        new_buf = []
-                else:
-                    if item and item[-1:] in b'\r\n':
-                        yield b''.join(new_buf)
-                        new_buf = []
+                if item and item[-1:] in crlf:
+                    yield _join(new_buf)
+                    new_buf = []
             buffer = new_buf
         if buffer:
-            yield string_join(buffer)
+            yield _join(buffer)
 
     # This hackery is necessary to merge 'foo\r' and '\n' into one item
     # of 'foo\r\n' if we were unlucky and we hit a chunk boundary.
-    previous = ''
+    previous = empty
     for item in _iter_basic_lines():
-        if item in ['\n', b'\n'] and previous[-1:] in ['\r', b'\r']:
+        if item == lf and previous[-1:] == cr:
             previous += item
-            item = ''
+            item = empty
         if previous:
             yield previous
         previous = item
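
The "hackery" still works in both modes because cr, lf, and empty were all built from the first item: a '\r\n' pair that straddles a chunk boundary is merged back into a single line ending. For instance (illustrative values):

    from werkzeug.wsgi import make_line_iter

    # The trailing '\r' of chunk one and the leading '\n' of chunk two
    # come back as one 'foo\r\n' line.
    list(make_line_iter([b'foo\r', b'\nbar\n']))   # [b'foo\r\n', b'bar\n']
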
@@ -737,29 +758,32 @@ def make_chunk_iter(stream, separator, limit=None, buffer_size=10 * 1024):
         is otherwise already limited).
     :param buffer_size: The optional buffer size.
     """
-    _read = make_chunk_iter_func(stream, limit, buffer_size)
-    separator_pattern = r'(%s)' % re.escape(separator)
-    _string_split = re.compile(separator_pattern).split
-    _bytes_split = re.compile(separator_pattern.encode('ascii')).split
-    if isinstance(separator, text_type):
-        string_separator = separator
-        bytes_separator = separator.encode('ascii')
+    _iter = _make_chunk_iter(stream, limit, buffer_size)
+
+    first_item = next(_iter, '')
+    if not first_item:
+        return
+
+    _iter = chain((first_item,), _iter)
+    if isinstance(first_item, text_type):
+        separator = to_unicode(separator)
+        _split = re.compile(r'(%s)' % re.escape(separator)).split
+        _join = u''.join
     else:
-        bytes_separator = separator
-        string_separator = separator.decode('ascii')
+        separator = to_bytes(separator)
+        _split = re.compile(b'(' + re.escape(separator) + b')').split
+        _join = b''.join
 
     buffer = []
     while 1:
-        new_data = _read()
+        new_data = next(_iter, '')
         if not new_data:
             break
-        if isinstance(new_data, text_type):
-            chunks = _string_split(new_data)
-        else:
-            chunks = _bytes_split(new_data)
+        chunks = _split(new_data)
         new_buf = []
         for item in chain(buffer, chunks):
-            if item in [string_separator, bytes_separator]:
-                yield string_join(new_buf)
+            if item == separator:
+                yield _join(new_buf)
                 new_buf = []
             else:
                 new_buf.append(item)
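
With the mode decided once from the first item, the separator is coerced to the matching type up front (to_unicode or to_bytes), so a single compiled pattern and join serve the whole stream. A small usage sketch (illustrative output):

    from werkzeug.wsgi import make_chunk_iter

    # A native-string separator is coerced to bytes to match byte chunks.
    list(make_chunk_iter([b'a=1&b', b'=2'], separator='&'))
    # -> [b'a=1', b'b=2']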