Skip to content

Commit

Permalink
Never try to decode bytes as UTF-8
Browse files Browse the repository at this point in the history
There're two issues in the existing code.

First, this heuristic often fails on short random strings,
making their type unreliable. For example, "t" (transaction ID) in
DHT protocol is just two bytes.

Second, the code assumed that they keys of dictionaries are always
UTF-8 encoded. Although that's true in many BitTorrent-related protocols,
bencoding doesn't mandate it, and exceptions do in fact exist:
see #6.

Also, change the code to consistently return bytes/str in all cases,
leaving support for passing encoding str/unicode strings, which are
transparently encoded to UTF-8.
  • Loading branch information
WGH- committed May 9, 2020
1 parent e8290df commit 44230fa
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 107 deletions.
41 changes: 9 additions & 32 deletions bencode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,9 @@ def decode_int(x, f):
return n, newf + 1


def decode_string(x, f, try_decode_utf8=True, force_decode_utf8=False):
def decode_string(x, f):
# type: (bytes, int, bool, bool) -> Tuple[bytes, int]
"""Decode torrent bencoded 'string' in x starting at f.
An attempt is made to convert the string to a python string from utf-8.
However, both string and non-string binary data is intermixed in the
torrent bencoding standard. So we have to guess whether the byte
sequence is a string or just binary data. We make this guess by trying
to decode (from utf-8), and if that fails, assuming it is binary data.
There are some instances where the data SHOULD be a string though.
You can check enforce this by setting force_decode_utf8 to True. If the
decoding from utf-8 fails, an UnidcodeDecodeError is raised. Similarly,
if you know it should not be a string, you can skip the decoding
attempt by setting try_decode_utf8=False.
"""
colon = x.index(b':', f)
n = int(x[f:colon])
Expand All @@ -84,13 +73,6 @@ def decode_string(x, f, try_decode_utf8=True, force_decode_utf8=False):
colon += 1
s = x[colon:colon + n]

if try_decode_utf8:
try:
return s.decode('utf-8'), colon + n
except UnicodeDecodeError:
if force_decode_utf8:
raise

return bytes(s), colon + n


Expand Down Expand Up @@ -135,7 +117,7 @@ def decode_dict(x, f, force_sort=True):
r, f = OrderedDict(), f + 1

while x[f:f + 1] != b'e':
k, f = decode_string(x, f, force_decode_utf8=True)
k, f = decode_string(x, f)
r[k], f = decode_func[x[f:f + 1]](x, f)

if force_sort:
Expand Down Expand Up @@ -219,13 +201,7 @@ def encode_bytes(x, r):

def encode_string(x, r):
# type: (str, Deque[bytes]) -> None
try:
s = x.encode('utf-8')
except UnicodeDecodeError:
encode_bytes(x, r)
return

r.extend((str(len(s)).encode('utf-8'), b':', s))
return encode_bytes(x.encode("UTF-8"), r)


def encode_list(x, r):
Expand All @@ -241,12 +217,13 @@ def encode_list(x, r):
def encode_dict(x, r):
# type: (Dict, Deque[bytes]) -> None
r.append(b'd')
ilist = list(x.items())
ilist.sort()

# force all keys to bytes, because str and bytes are incomparable
ilist = [(k if type(k) == type(b"") else k.encode("UTF-8"), v) for k, v in x.items()]
ilist.sort(key=lambda kv: kv[0])

for k, v in ilist:
k = k.encode('utf-8')
r.extend((str(len(k)).encode('utf-8'), b':', k))
encode_func[type(k)](k, r)
encode_func[type(v)](v, r)

r.append(b'e')
Expand All @@ -263,7 +240,7 @@ def encode_dict(x, r):
encode_func[IntType] = encode_int
encode_func[ListType] = encode_list
encode_func[LongType] = encode_int
encode_func[StringType] = encode_string
encode_func[StringType] = encode_bytes
encode_func[TupleType] = encode_list
encode_func[UnicodeType] = encode_string

Expand Down
77 changes: 14 additions & 63 deletions tests/bencode_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,66 +15,43 @@


VALUES = [
(0, 'i0e'),
(1, 'i1e'),
(10, 'i10e'),
(42, 'i42e'),
(-42, 'i-42e'),
(True, 'i1e'),
(False, 'i0e'),
('spam', '4:spam'),
('parrot sketch', '13:parrot sketch'),
(['parrot sketch', 42], 'l13:parrot sketchi42ee'),
({'foo': 42, 'bar': 'spam'}, 'd3:bar4:spam3:fooi42ee')
(0, b'i0e'),
(1, b'i1e'),
(10, b'i10e'),
(42, b'i42e'),
(-42, b'i-42e'),
(True, b'i1e'),
(False, b'i0e'),
(b'spam', b'4:spam'),
(b'parrot sketch', b'13:parrot sketch'),
([b'parrot sketch', 42], b'l13:parrot sketchi42ee'),
({b'foo': 42, b'bar': b'spam'}, b'd3:bar4:spam3:fooi42ee')
]

if OrderedDict is not None:
VALUES.append((OrderedDict((
('bar', 'spam'),
('foo', 42)
)), 'd3:bar4:spam3:fooi42ee'))
(b'bar', b'spam'),
(b'foo', 42)
)), b'd3:bar4:spam3:fooi42ee'))


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_encode():
"""Encode should give known result with known input."""
for plain, encoded in VALUES:
assert encoded.encode('utf-8') == bencode(plain)


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_encode_py2():
"""Encode should give known result with known input."""
for plain, encoded in VALUES:
assert encoded == bencode(plain)


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_encode_bencached():
"""Ensure Bencached objects can be encoded."""
assert bencode([Bencached(bencode('test'))]) == b'l4:teste'


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_encode_bencached_py2():
"""Ensure Bencached objects can be encoded."""
assert bencode([Bencached(bencode('test'))]) == 'l4:teste'


def test_encode_bytes():
"""Ensure bytes can be encoded."""
assert bencode(b'\x9c') == b'1:\x9c'


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_decode():
"""Decode should give known result with known input."""
for plain, encoded in VALUES:
assert plain == bdecode(encoded.encode('utf-8'))


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_decode_py2():
"""Decode should give known result with known input."""
for plain, encoded in VALUES:
assert plain == bdecode(encoded)
Expand All @@ -85,15 +62,7 @@ def test_decode_bytes():
assert bdecode(b'1:\x9c') == b'\x9c'


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_encode_roundtrip():
"""Consecutive calls to decode and encode should deliver the original data again."""
for plain, encoded in VALUES:
assert encoded.encode('utf-8') == bencode(bdecode(encoded.encode('utf-8')))


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_encode_roundtrip_py2():
"""Consecutive calls to decode and encode should deliver the original data again."""
for plain, encoded in VALUES:
assert encoded == bencode(bdecode(encoded))
Expand Down Expand Up @@ -142,33 +111,15 @@ def test_dictionary_sorted():
assert encoded.index(b'zoo') > encoded.index(b'bar')


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_dictionary_unicode():
"""Test the handling of unicode in dictionaries."""
encoded = bencode({u'foo': 42, 'bar': {u'sketch': u'parrot', 'foobar': 23}})

assert encoded == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'.encode('utf-8')


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_dictionary_unicode_py2():
"""Test the handling of unicode in dictionaries."""
encoded = bencode({u'foo': 42, 'bar': {u'sketch': u'parrot', 'foobar': 23}})

assert encoded == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'


@pytest.mark.skipif(sys.version_info[0] < 3, reason="Requires: Python 3+")
def test_dictionary_nested():
"""Test the handling of nested dictionaries."""
encoded = bencode({'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}})

assert encoded == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'.encode('utf-8')


@pytest.mark.skipif(sys.version_info[0] != 2, reason="Requires: Python 2")
def test_dictionary_nested_py2():
"""Test the handling of nested dictionaries."""
encoded = bencode({'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}})

assert encoded == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'
24 changes: 12 additions & 12 deletions tests/file_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_read_file():
with open(os.path.join(FIXTURE_DIR, 'alpha'), 'rb') as fp:
data = bread(fp)

assert data == {'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}}
assert data == {b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}}


def test_read_path():
"""Test the reading of bencode paths."""
data = bread(os.path.join(FIXTURE_DIR, 'alpha'))

assert data == {'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}}
assert data == {b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}}


@pytest.mark.skipif(sys.version_info < (3, 4), reason="Requires: Python 3.4+")
Expand All @@ -39,30 +39,30 @@ def test_read_pathlib():

data = bread(Path(FIXTURE_DIR, 'alpha'))

assert data == {'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}}
assert data == {b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}}


def test_write_file():
"""Test the writing of bencode paths."""
with open(os.path.join(TEMP_DIR, 'beta'), 'wb') as fp:
bwrite(
{'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}},
{b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}},
fp
)

with open(os.path.join(TEMP_DIR, 'beta'), 'r') as fp:
assert fp.read() == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'
with open(os.path.join(TEMP_DIR, 'beta'), 'rb') as fp:
assert fp.read() == b'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'


def test_write_path():
"""Test the writing of bencode files."""
bwrite(
{'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}},
{b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}},
os.path.join(TEMP_DIR, 'beta')
)

with open(os.path.join(TEMP_DIR, 'beta'), 'r') as fp:
assert fp.read() == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'
with open(os.path.join(TEMP_DIR, 'beta'), 'rb') as fp:
assert fp.read() == b'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'


@pytest.mark.skipif(sys.version_info < (3, 4), reason="Requires: Python 3.4+")
Expand All @@ -71,9 +71,9 @@ def test_write_pathlib():
from pathlib import Path

bwrite(
{'foo': 42, 'bar': {'sketch': 'parrot', 'foobar': 23}},
{b'foo': 42, b'bar': {b'sketch': b'parrot', b'foobar': 23}},
Path(TEMP_DIR, 'beta')
)

with open(os.path.join(TEMP_DIR, 'beta'), 'r') as fp:
assert fp.read() == 'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'
with open(os.path.join(TEMP_DIR, 'beta'), 'rb') as fp:
assert fp.read() == b'd3:bard6:foobari23e6:sketch6:parrote3:fooi42ee'

0 comments on commit 44230fa

Please sign in to comment.