Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

add allow_overwrite + unit test; closes #8

  • Loading branch information...
commit 287a1e58002e2bfae4f49821b1c427b6437b1c0b 1 parent 3f624ce
@tomerfiliba tomerfiliba authored
View
1  construct/__init__.py
@@ -73,4 +73,5 @@
'TunnelAdapter', 'UBInt16', 'UBInt32', 'UBInt64', 'UBInt8', 'ULInt16',
'ULInt32', 'ULInt64', 'ULInt8', 'UNInt16', 'UNInt32', 'UNInt64', 'UNInt8',
'Union', 'ValidationError', 'Validator', 'Value', "Magic", "this",
+ "OverwriteError"
]
View
7 construct/core.py
@@ -26,6 +26,8 @@ class SelectError(ConstructError):
pass
class TerminatorError(ConstructError):
pass
+class OverwriteError(ValueError):
+ pass
#===============================================================================
# abstract constructs
@@ -627,9 +629,10 @@ class Struct(Construct):
UBInt8("third_element"),
)
"""
- __slots__ = ["subcons", "nested"]
+ __slots__ = ["subcons", "nested", "allow_overwrite"]
def __init__(self, name, *subcons, **kw):
self.nested = kw.pop("nested", True)
+ self.allow_overwrite = kw.pop("allow_overwrite", False)
if kw:
raise TypeError("the only keyword argument accepted is 'nested'", kw)
Construct.__init__(self, name)
@@ -651,6 +654,8 @@ def _parse(self, stream, context):
else:
subobj = sc._parse(stream, context)
if sc.name is not None:
+ if sc.name in obj and not self.allow_overwrite:
+ raise OverwriteError("%r would be overwritten but allow_overwrite is False" % (sc.name,))
obj[sc.name] = subobj
context[sc.name] = subobj
return obj
View
2  construct/version.py
@@ -1,3 +1,3 @@
version = (2, 5, 0)
version_string = "2.5.0"
-release_date = "2012.12.22"
+release_date = "2013.01.12"
View
58 tests/test_overwrite.py
@@ -0,0 +1,58 @@
+import unittest
+from construct import Struct, Byte, Embedded, OverwriteError
+
+
+class TestOverwrite(unittest.TestCase):
+ def test_overwrite(self):
+ s = Struct("s",
+ Byte("a"),
+ Byte("a"),
+ allow_overwrite = True
+ )
+ self.assertEqual(s.parse("\x01\x02").a, 2)
+
+ s = Struct("s",
+ Byte("a"),
+ Embedded(Struct("b",
+ Byte("a"),
+ allow_overwrite = True
+ )),
+ )
+ self.assertEqual(s.parse("\x01\x02").a, 2)
+
+ s = Struct("s",
+ Embedded(Struct("b",
+ Byte("a"),
+ )),
+ Byte("a"),
+ allow_overwrite = True
+ )
+ self.assertEqual(s.parse("\x01\x02").a, 2)
+
+ def test_no_overwrite(self):
+ s = Struct("s",
+ Byte("a"),
+ Byte("a"),
+ )
+ self.assertRaises(OverwriteError, s.parse, "\x01\x02")
+
+ s = Struct("s",
+ Byte("a"),
+ Embedded(Struct("b",
+ Byte("a"),
+ )),
+ allow_overwrite = True
+ )
+ self.assertRaises(OverwriteError, s.parse, "\x01\x02")
+
+ s = Struct("s",
+ Embedded(Struct("b",
+ Byte("a"),
+ allow_overwrite = True
+ )),
+ Byte("a"),
+ )
+ self.assertRaises(OverwriteError, s.parse, "\x01\x02")
+
+if __name__ == "__main__":
+ unittest.main()
Please sign in to comment.
Something went wrong with that request. Please try again.