diff --git a/pfp/interp.py b/pfp/interp.py index b495cce..49e938d 100644 --- a/pfp/interp.py +++ b/pfp/interp.py @@ -52,7 +52,9 @@ def _pfp__init(self, stream): self._pfp__node.args, scope, self, None ) param_list = params.instantiate(scope, struct_args, self._pfp__interp) - super(self.__class__, self)._pfp__init(stream) + + if hasattr(super(self.__class__, self), "_pfp__init"): + super(self.__class__, self)._pfp__init(stream) new_class = type( struct_cls.__name__ + "_", (struct_cls,), {"_pfp__init": _pfp__init} @@ -84,13 +86,28 @@ def StructUnionTypeRef(curr_scope, typedef_name, refd_name, interp, node): elif isinstance(node, AST.Union): cls = fields.Union - def __new__(self, *args, **kwargs): + def __new__(cls_, *args, **kwargs): refd_type = curr_scope.get_type(refd_name) if refd_type is None: refd_node = node else: refd_node = refd_type._pfp__node - return StructUnionDef(typedef_name, interp, refd_node)(*args, **kwargs) + + def merged_init(self, stream): + if six.PY3: + cls_._pfp__init(self, stream) + else: + cls_._pfp__init.__func__(self, stream) + self._pfp__init_orig(stream) + + overrides = {} + if hasattr(cls_, "_pfp__init"): + overrides["_pfp__init"] = merged_init + + res = base_cls = StructUnionDef( + typedef_name, interp, refd_node, overrides=overrides, + ) + return res(*args, **kwargs) new_class = type( typedef_name, @@ -102,13 +119,16 @@ def __new__(self, *args, **kwargs): return new_class - -def StructUnionDef(typedef_name, interp, node): +def StructUnionDef(typedef_name, interp, node, overrides=None, cls=None): + if overrides is None: + overrides = {} if isinstance(node, AST.Struct): - cls = fields.Struct + if cls is None: + cls = fields.Struct decls = StructDecls(node.decls, node.coord) elif isinstance(node, AST.Union): - cls = fields.Union + if cls is None: + cls = fields.Union decls = UnionDecls(node.decls, node.coord) # this is so that we can have all nested structs added to @@ -117,7 +137,11 @@ def StructUnionDef(typedef_name, interp, node): # the new struct to not be added to its parent, and the user would # not be able to see how far the script got def __init__(self, stream=None, metadata_processor=None, do_init=True): - cls.__init__(self, stream, metadata_processor=metadata_processor) + cls.__init__( + self, + stream, + metadata_processor=metadata_processor, + ) if do_init: self._pfp__init(stream) @@ -125,15 +149,22 @@ def __init__(self, stream=None, metadata_processor=None, do_init=True): def _pfp__init(self, stream): self._pfp__interp._handle_node(decls, ctxt=self, stream=stream) + cls_members = { + "__init__": __init__, + "_pfp__init": _pfp__init, + "_pfp__node": node, + "_pfp__interp": interp, + } + + for k, v in six.iteritems(overrides or {}): + if k in cls_members: + cls_members[k + "_orig"] = cls_members[k] + cls_members[k] = v + new_class = type( typedef_name, (cls,), - { - "__init__": __init__, - "_pfp__init": _pfp__init, - "_pfp__node": node, - "_pfp__interp": interp, - }, + cls_members, ) return new_class diff --git a/requirements.txt b/requirements.txt index 819a4aa..e097781 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -py010parser>=0.1.15 +py010parser>=0.1.17 six>=1.10.0,<2.0.0 -intervaltree>=3.0.2,<4.0.0 \ No newline at end of file +intervaltree>=3.0.2,<4.0.0 diff --git a/tests/test_struct_union.py b/tests/test_struct_union.py index bbdf8ab..d0197e6 100644 --- a/tests/test_struct_union.py +++ b/tests/test_struct_union.py @@ -94,7 +94,7 @@ def test_struct_vit9696_5(self): LittleEndian(); ME s; """, - debug=True, + debug=False, ) assert dom.s.magic == "\x00\x01\x02\x03" assert dom.s.filesize == 0x03020100 @@ -239,6 +239,30 @@ def test_struct_with_parameters3(self): self.assertEqual(dom.l.c[1], 2) self.assertEqual(dom.l.c[2], 3) + def test_typedefd_struct_with_parameters(self): + dom = self._test_parse_build( + "\x01\x02\x03\x04\x01\x02\x03", + """ + struct TEST_STRUCT(int arraySize, int arraySize2) + { + uchar b[arraySize]; + uchar c[arraySize2]; + }; + local int bytes = 4; + typedef struct TEST_STRUCT NEW_STRUCT; + NEW_STRUCT l(bytes, 3); + """, + ) + self.assertEqual(len(dom.l.b), 4) + self.assertEqual(dom.l.b[0], 1) + self.assertEqual(dom.l.b[1], 2) + self.assertEqual(dom.l.b[2], 3) + self.assertEqual(dom.l.b[3], 4) + self.assertEqual(len(dom.l.c), 3) + self.assertEqual(dom.l.c[0], 1) + self.assertEqual(dom.l.c[1], 2) + self.assertEqual(dom.l.c[2], 3) + def test_struct_decl_with_struct_keyword(self): dom = self._test_parse_build( "ABCD",