Skip to content

Commit

Permalink
Merge pull request #79 from riggs/master
Browse files Browse the repository at this point in the history
Bugfix & enhancements in FlagsEnum. Many unit tests.
  • Loading branch information
arekbulski committed Aug 26, 2016
2 parents 948b53c + f345191 commit 3136f4b
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 34 deletions.
13 changes: 9 additions & 4 deletions construct/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,19 @@ def __init__(self, subcon, flags):
self.flags = flags
def _encode(self, obj, context):
flags = 0
for name, value in self.flags.items():
if getattr(obj, name, False):
flags |= value
try:
for name, value in obj.items():
if value:
flags |= self.flags[name]
except AttributeError:
raise MappingError("not a mapping type: %r" % (obj,))
except KeyError:
raise MappingError("unknown flag: %s" % name)
return flags
def _decode(self, obj, context):
obj2 = FlagsContainer()
for name, value in self.flags.items():
setattr(obj2, name, bool(obj & value))
obj2[name] = bool(obj & value)
return obj2

class StringAdapter(Adapter):
Expand Down
11 changes: 10 additions & 1 deletion construct/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from struct import Struct as Packer
import sys
import collections

from construct.lib.py3compat import BytesIO, advance_iterator, bchr
from construct.lib import Container, ListContainer, LazyContainer
Expand Down Expand Up @@ -452,6 +453,7 @@ def _build(self, obj, stream, context):
def _sizeof(self, context):
return self.subcon._sizeof(context) * self.countfunc(context)


class Range(Subconstruct):
r"""
A range-array. The subcon will iterate between ``mincount`` to ``maxcount``
Expand Down Expand Up @@ -501,12 +503,14 @@ class Range(Subconstruct):
"""

__slots__ = ["mincount", "maxcout"]

def __init__(self, mincount, maxcout, subcon):
super(Range, self).__init__(subcon)
self.mincount = mincount
self.maxcout = maxcout
self._clear_flag(self.FLAG_COPY_CONTEXT)
self._set_flag(self.FLAG_DYNAMIC)

def _parse(self, stream, context):
obj = ListContainer()
c = 0
Expand All @@ -524,10 +528,13 @@ def _parse(self, stream, context):
except ConstructError:
if c < self.mincount:
raise RangeError("expected %d to %d, found %d" %
(self.mincount, self.maxcout, c), sys.exc_info()[1])
(self.mincount, self.maxcout, c))
stream.seek(pos)
return obj

def _build(self, obj, stream, context):
if not isinstance(obj, collections.Sequence):
raise RangeError("expected sequence type, found %s" % type(obj))
if len(obj) < self.mincount or len(obj) > self.maxcout:
raise RangeError("expected %d to %d, found %d" %
(self.mincount, self.maxcout, len(obj)))
Expand All @@ -549,9 +556,11 @@ def _build(self, obj, stream, context):
if cnt < self.mincount:
raise RangeError("expected %d to %d, found %d" %
(self.mincount, self.maxcout, len(obj)), sys.exc_info()[1])

def _sizeof(self, context):
raise SizeofError("can't calculate size")


class RepeatUntil(Subconstruct):
r"""
An array that repeats until the predicate indicates it to stop. Note that
Expand Down
54 changes: 35 additions & 19 deletions construct/lib/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Various containers.
"""

def recursion_lock(retval, lock_name = "__recursion_lock__"):

def recursion_lock(retval, lock_name="__recursion_lock__"):
def decorator(func):
def wrapper(self, *args, **kw):
if getattr(self, lock_name, False):
Expand All @@ -12,10 +13,13 @@ def wrapper(self, *args, **kw):
return func(self, *args, **kw)
finally:
setattr(self, lock_name, False)

wrapper.__name__ = func.__name__
return wrapper

return decorator


class Container(dict):
"""
A generic container of attributes.
Expand All @@ -24,33 +28,41 @@ class Container(dict):
"""
__slots__ = ["__keys_order__"]

def __init__(self, **kw):
def __init__(self, *args, **kw):
object.__setattr__(self, "__keys_order__", [])
for arg in args:
for k, v in arg.items():
self[k] = v
for k, v in kw.items():
self[k] = v

def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise AttributeError(name)

def __setitem__(self, key, val):
if key not in self:
self.__keys_order__.append(key)
self.__keys_order__.append(key)
dict.__setitem__(self, key, val)

def __delitem__(self, key):
dict.__delitem__(self, key)
self.__keys_order__.remove(key)

__delattr__ = __delitem__
__setattr__ = __setitem__

def clear(self):
dict.clear(self)
del self.__keys_order__[:]

def pop(self, key, *default):
val = dict.pop(self, key, *default)
self.__keys_order__.remove(key)
return val

def popitem(self):
k, v = dict.popitem(self)
self.__keys_order__.remove(k)
Expand Down Expand Up @@ -90,37 +102,43 @@ def _search(self, name, search_all):
pass
if search_all:
return items
else:
else:
return None

def search(self, name):
return self._search(name, False)

def search_all(self, name):
return self._search(name, True)

__update__ = update
__copy__ = copy

def __iter__(self):
return iter(self.__keys_order__)

iterkeys = __iter__

def itervalues(self):
return (self[k] for k in self.__keys_order__)

def iteritems(self):
return ((k, self[k]) for k in self.__keys_order__)

def keys(self):
return self.__keys_order__

def values(self):
return list(self.itervalues())

def items(self):
return list(self.iteritems())

def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self))

@recursion_lock("<...>")
def __pretty_str__(self, nesting = 1, indentation = " "):
def __pretty_str__(self, nesting=1, indentation=" "):
attrs = []
ind = indentation * nesting
for k, v in self.iteritems():
Expand All @@ -147,7 +165,7 @@ class FlagsContainer(Container):
"""

@recursion_lock("<...>")
def __pretty_str__(self, nesting = 1, indentation = " "):
def __pretty_str__(self, nesting=1, indentation=" "):
attrs = []
ind = indentation * nesting
for k in self.keys():
Expand All @@ -156,9 +174,9 @@ def __pretty_str__(self, nesting = 1, indentation = " "):
attrs.append(ind + k)
if not attrs:
return "%s()" % (self.__class__.__name__,)
attrs.insert(0, self.__class__.__name__+ ":")
attrs.insert(0, self.__class__.__name__ + ":")
return "\n".join(attrs)


class ListContainer(list):
"""
Expand All @@ -170,7 +188,7 @@ def __str__(self):
return self.__pretty_str__()

@recursion_lock("[...]")
def __pretty_str__(self, nesting = 1, indentation = " "):
def __pretty_str__(self, nesting=1, indentation=" "):
if not self:
return "[]"
ind = indentation * nesting
Expand All @@ -186,7 +204,7 @@ def __pretty_str__(self, nesting = 1, indentation = " "):
lines.append(indentation * (nesting - 1))
lines.append("]")
return "".join(lines)

def _search(self, name, search_all):
items = []
for item in self:
Expand All @@ -203,16 +221,15 @@ def _search(self, name, search_all):
return items
else:
return None

def search(self, name):
return self._search(name, False)

def search_all(self, name):
return self._search(name, True)


class LazyContainer(object):

__slots__ = ["subcon", "stream", "pos", "context", "_value"]

def __init__(self, subcon, stream, pos, context):
Expand All @@ -234,7 +251,7 @@ def __ne__(self, other):
def __str__(self):
return self.__pretty_str__()

def __pretty_str__(self, nesting = 1, indentation = " "):
def __pretty_str__(self, nesting=1, indentation=" "):
if self._value is NotImplemented:
text = "<unread>"
elif hasattr(self._value, "__pretty_str__"):
Expand All @@ -261,4 +278,3 @@ def _get_value(self):
value = property(_get_value)

has_value = property(lambda self: self._value is not NotImplemented)

15 changes: 15 additions & 0 deletions tests/lib/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,21 @@ def test_order(self):
c.update(words)
self.assertEqual([k for k, _ in words], list(c.keys()))

def test_dict_arg(self):
c = Container({'a': 1})
d = Container(a=1)
self.assertEqual(c, d)

def test_multiple_dict_args(self):
c = Container({'a': 1, 'b': 42}, {'b': 2})
d = Container(a=1, b=2)
self.assertEqual(c, d)

def test_dict_and_kw_args(self):
c = Container({'b': 42, 'c': 43}, {'a': 1, 'b': 2, 'c': 4}, c=3, d=4)
d = Container(a=1, b=2, c=3, d=4)
self.assertEqual(c, d)

class TestListContainer(unittest.TestCase):

def test_str(self):
Expand Down
32 changes: 27 additions & 5 deletions tests/test_adaptors.py → tests/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
import unittest

from construct import FlagsAdapter, Byte, FlagsContainer, MappingError
from construct import Field, UBInt8
from construct import OneOf, NoneOf, HexDumpAdapter
from construct import ValidationError
from construct.protocols.layer3.ipv4 import IpAddress


class TestFlagsAdapter(unittest.TestCase):

def setUp(self):
self.fa = FlagsAdapter(Byte('HID_type'), {'feature': 4, 'output': 2, 'input': 1})

def test_trivial(self):
pass

def test_parse(self):
self.assertEqual(self.fa.parse(b'\x04'), FlagsContainer(feature=True, output=False, input=False))

def test_build(self):
self.assertEqual(self.fa.build(dict(feature=True, output=True, input=False)), b'\x06')
self.assertEqual(self.fa.build(dict(feature=True)), b'\x04')
self.assertEqual(self.fa.build(dict()), b'\x00')

def test_build_unknown_flag_given(self):
self.assertRaises(MappingError, self.fa.build, dict(unknown=True, feature=True))


class TestHexDumpAdapter(unittest.TestCase):

def setUp(self):
Expand All @@ -15,8 +36,8 @@ def test_trivial(self):
pass

def test_parse(self):
parsed = self.hda.parse(b"abcdef")
self.assertEqual(parsed, b"abcdef")
parsed = self.hda.parse(b'abcdef')
self.assertEqual(parsed, b'abcdef')

def test_build(self):
self.assertEqual(self.hda.build(b"abcdef"), b"abcdef")
Expand Down Expand Up @@ -73,11 +94,12 @@ def test_trivial(self):
pass

def test_parse(self):
self.assertEqual(self.ipa.parse(b"\x7f\x80\x81\x82"), "127.128.129.130")
self.assertEqual(self.ipa.parse(b"\x7f\x80\x81\x82"),
"127.128.129.130")

def test_build(self):
self.assertEqual(self.ipa.build("127.1.2.3"), b"\x7f\x01\x02\x03")
self.assertEqual(self.ipa.build("127.1.2.3"),
b"\x7f\x01\x02\x03")

def test_build_invalid(self):
self.assertRaises(ValueError, self.ipa.build, "300.1.2.3")

0 comments on commit 3136f4b

Please sign in to comment.