Skip to content

Commit

Permalink
add validate_pointers for tuples and arrays, add tests and more
Browse files Browse the repository at this point in the history
descriptive comments
  • Loading branch information
pacrob committed Mar 1, 2024
1 parent 75aab7c commit 82c1ad3
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 7 deletions.
84 changes: 80 additions & 4 deletions eth_abi/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from eth_abi.exceptions import (
InsufficientDataBytes,
InvalidPointer,
NonEmptyPaddingBytes,
)
from eth_abi.utils.numeric import (
Expand Down Expand Up @@ -78,13 +79,13 @@ def __init__(self, *args, **kwargs):
self._frames = []
self._total_offset = 0

def seek_in_frame(self, pos, *args, **kwargs):
def seek_in_frame(self, pos: int, *args: Any, **kwargs: Any) -> None:
    """
    Move the cursor to ``pos``, measured from the total offset of the
    current contextual frames.

    Any additional positional or keyword arguments are forwarded unchanged
    to ``seek``.
    """
    absolute_pos = self._total_offset + pos
    self.seek(absolute_pos, *args, **kwargs)

def push_frame(self, offset):
def push_frame(self, offset: int) -> None:
"""
Pushes a new contextual frame onto the stack with the given offset and a
return position at the current cursor position then seeks to the new
Expand Down Expand Up @@ -131,6 +132,13 @@ def __call__(self, stream: ContextFramesBytesIO) -> Any:


class HeadTailDecoder(BaseDecoder):
"""
Decoder for a dynamic element of a dynamic container (a dynamic array, or a sized
array or tuple that contains dynamic elements). A dynamic element consists of a
pointer, aka offset, which is located in the head section of the encoded container,
and the actual value, which is located in the tail section of the encoding.
"""

is_dynamic = True

tail_decoder = None
Expand All @@ -141,13 +149,18 @@ def validate(self):
if self.tail_decoder is None:
raise ValueError("No `tail_decoder` set")

def decode(self, stream):
def decode(self, stream: ContextFramesBytesIO) -> Any:
    """
    Follow the pointer at the cursor and decode the value it refers to.

    The pointer is read with ``decode_uint_256`` (which advances the cursor
    by 32 bytes), a frame is pushed at that offset, the tail decoder is run
    inside the frame, and the frame is popped again before the decoded
    value is returned.

    Raises ``AssertionError`` if no ``tail_decoder`` has been set.
    """
    # Read the offset stored in the head section.
    pointer = decode_uint_256(stream)

    # Enter a frame positioned at the start of the value.
    stream.push_frame(pointer)

    tail_decoder = self.tail_decoder
    if tail_decoder is None:
        # Explicit guard so that mypy can narrow the optional attribute.
        raise AssertionError("`tail_decoder` is None")

    decoded = tail_decoder(stream)

    # Leave the frame, restoring the cursor.
    stream.pop_frame()
    return decoded
Expand All @@ -172,8 +185,43 @@ def validate(self):
if self.decoders is None:
raise ValueError("No `decoders` set")

def validate_pointers(self, stream: ContextFramesBytesIO) -> None:
    """
    Check every pointer in the tuple's head section before decoding.

    Walks the component decoders once. For each ``HeadTailDecoder`` the
    next 32 bytes are read as an offset relative to the position the stream
    had on entry. Every other decoder is simply run so that the cursor
    advances past its data. The cursor is put back where it started once
    all pointers have been accepted.

    Raises ``InvalidPointer`` when an offset resolves to a position inside
    the head section or at/after the end of the stream buffer.
    """
    start = stream.tell()

    # Size of the head in 32-byte words: a decoder exposing ``array_size``
    # contributes that many words, any other decoder contributes one.
    head_words = 0
    for decoder in self.decoders:
        head_words += getattr(decoder, "array_size", 1)

    head_end = start + 32 * head_words
    stream_size = len(stream.getbuffer())

    for decoder in self.decoders:
        if not isinstance(decoder, HeadTailDecoder):
            # Not a pointer: let the decoder consume its own bytes.
            decoder(stream)
            continue

        # The next 32 bytes are a pointer.
        target = start + decode_uint_256(stream)
        if head_end <= target < stream_size:
            continue

        # The target lies within the offsets section or past the end of
        # the stream; neither is a valid place for the data.
        raise InvalidPointer(
            "Invalid pointer in tuple at location "
            f"{stream.tell() - 32} in payload"
        )

    # Rewind so that the actual decoding starts from the original position.
    stream.seek(start)

@to_tuple  # type: ignore[misc] # untyped decorator
def decode(self, stream: ContextFramesBytesIO) -> Generator[Any, None, None]:
    """
    Validate the tuple's pointers, then decode each component in order.

    The values are yielded one by one and handed to the ``to_tuple``
    decorator.
    """
    self.validate_pointers(stream)
    yield from (decode_component(stream) for decode_component in self.decoders)

Expand Down Expand Up @@ -248,6 +296,30 @@ def from_type_str(cls, abi_type, registry):
# If array dimension is dynamic
return DynamicArrayDecoder(item_decoder=item_decoder)

def validate_pointers(self, stream: ContextFramesBytesIO, array_size: int) -> None:
    """
    Check the ``array_size`` pointers at the head of an array of dynamic items.

    Nothing is done unless the item decoder is a ``HeadTailDecoder``. In
    that case each of the next ``array_size`` 32-byte words is read as an
    offset relative to the position the stream had on entry, and the cursor
    is put back where it started once all pointers have been accepted.

    Raises ``InvalidPointer`` when an offset resolves to a position inside
    the offsets section or at/after the end of the stream buffer.
    """
    if not isinstance(self.item_decoder, HeadTailDecoder):
        return

    start = stream.tell()
    head_end = start + 32 * array_size
    stream_size = len(stream.getbuffer())

    for _ in range(array_size):
        target = start + decode_uint_256(stream)
        if not (head_end <= target < stream_size):
            # The target lies within the offsets section or past the end
            # of the stream; neither is a valid place for the data.
            raise InvalidPointer(
                "Invalid pointer in array at location "
                f"{stream.tell() - 32} in payload"
            )

    # Rewind so that the actual decoding starts from the original position.
    stream.seek(start)


class SizedArrayDecoder(BaseArrayDecoder):
array_size = None
Expand All @@ -261,6 +333,8 @@ def __init__(self, **kwargs):
def decode(self, stream):
if self.item_decoder is None:
raise AssertionError("`item_decoder` is None")

self.validate_pointers(stream, self.array_size)
for _ in range(self.array_size):
yield self.item_decoder(stream)

Expand All @@ -275,6 +349,8 @@ def decode(self, stream):
stream.push_frame(32)
if self.item_decoder is None:
raise AssertionError("`item_decoder` is None")

self.validate_pointers(stream, array_size)
for _ in range(array_size):
yield self.item_decoder(stream)
stream.pop_frame()
Expand Down
9 changes: 7 additions & 2 deletions eth_abi/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ class DecodingError(Exception):

class InsufficientDataBytes(DecodingError):
"""
Raised when there are insufficient data to decode a value for a given ABI
type.
Raised when there are insufficient data to decode a value for a given ABI type.
"""


Expand All @@ -58,6 +57,12 @@ class NonEmptyPaddingBytes(DecodingError):
"""


class InvalidPointer(DecodingError):
    """
    Raised during decoding when a pointer (offset) found in the ABI-encoded
    payload does not refer to a valid location for its value.
    """


class ParseError(parsimonious.ParseError): # type: ignore[misc] # subclasses Any
"""
Raised when an ABI type string cannot be parsed.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/226.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
During decoding, verify that all pointers in arrays and tuples point to a valid location in the payload.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
],
"test": [
"pytest>=7.0.0",
"pytest-timeout>=2.0.0",
"pytest-xdist>=2.4.0",
"pytest-pythonpath>=0.7.1",
"eth-hash[pycryptodome]",
Expand Down
53 changes: 52 additions & 1 deletion tests/abi/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from eth_abi.exceptions import (
InsufficientDataBytes,
InvalidPointer,
)
from eth_abi.grammar import (
parse,
Expand Down Expand Up @@ -68,7 +69,6 @@ def test_abi_decode_for_single_dynamic_types(
)

(actual,) = decode([type_str], abi_encoding, strict=strict)

assert actual == expected


Expand Down Expand Up @@ -198,3 +198,54 @@ def test_abi_decode_with_shorter_data_than_32_bytes(types, hex_data, expected):
# without the flag set (i.e. assert the default behavior is always ``strict=True``).
with pytest.raises(InsufficientDataBytes):
decode(types, bytes.fromhex(hex_data))


@pytest.mark.parametrize(
"typestring,malformed_payload",
(
(
["uint256[][][][][][][][][][]"],
("0" * 62 + "20") * 10 + "00" * 2056,
),
(
["uint256[][][][][][][][][][][][]"],
"0" * 62 + "20" + "0" * 62 + "a0" + ("0" * 62 + "20") * 9 + "00" * 1024,
),
(
["uint8[]", "uint[2]", "uint8[]"],
"00000000000000000000000000000000000000000000000000000000000000600000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", # noqa: E501
),
(
["(uint8[2],uint8[])"],
"00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", # noqa: E501
),
(
["(uint8[2],uint8[])"],
"00000000000000000000000000000000000000000000000000000000000000f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", # noqa: E501
),
(
["uint8[]"],
"00000000000000000000000000000000000000000000000000000000000000f00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000200000000000000000000000000000000000000000000000000000000000000000", # noqa: E501
),
(
["(uint8[],uint8[8],uint8[])"],
"0000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000016000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", # noqa: E501
),
),
ids=(
"nested array with all pointers equal to 0x20",
"nested array user example",
"separate dynamic, sized, dynamic arrays",
"tuple of sized array and dynamic array",
"pointer beyond end of data for tuple",
"pointer beyond end of data for array",
"tuple of arrays to check length of head section of tuple calcd correctly",
),
)
@pytest.mark.timeout(1)
def test_decode_nested_dynamic_array_with_invalid_pointer_fails_fast(
    typestring, malformed_payload
):
    """
    Each malformed payload must be rejected with ``InvalidPointer`` inside
    the one-second timeout.
    """
    payload = bytearray.fromhex(malformed_payload)
    with pytest.raises(InvalidPointer, match=r"^Invalid pointer in"):
        decode(typestring, payload)

0 comments on commit 82c1ad3

Please sign in to comment.