diff --git a/cwrap/codegen.py b/cwrap/codegen.py index 7cb5af9..c8b5c93 100644 --- a/cwrap/codegen.py +++ b/cwrap/codegen.py @@ -1,61 +1,52 @@ -from collections import defaultdict -from cStringIO import StringIO import os -import re import subprocess +import tempfile -import pycparser - -import cy_ast -import translate +import parser import renderers -class CodeGenerator(object): - - def __init__(self, config): - self.config = config - self._translator = translate.ASTTranslator() - self._extern_renderer = renderers.ExternRenderer() - - def _preprocess(self, header): - cmds = ['gcc'] - for inc_dir in self.config.include_dirs: - cmds.append('-I' + inc_dir) - cmds.append('-E') - cmds.append(header.path) - p = subprocess.Popen(cmds, stdout=subprocess.PIPE) - c_code, _ = p.communicate() - - # we need to remove any gcc __attribute__ declarations - # from the code as this will cause PyCParser to fail. - c_code = re.sub('__attribute__\(\(.*?\)\)', '', c_code) - - return c_code - - def _parse(self, c_code): - parser = pycparser.CParser() - ast = parser.parse(c_code) - return ast - - def _translate(self, c_ast): - module_ast = self._translator.translate(c_ast) - return module_ast - - def _render_extern(self, module_ast, header): - return self._extern_renderer.render(module_ast, header) - - def generate(self): - save_dir = self.config.save_dir - for header in self.config.headers: - c_code = self._preprocess(header) - c_ast = self._parse(c_code) - module_ast = self._translate(c_ast) - - extern_code = self._render_extern(module_ast, header) - extern_name = '_' + header.mod_name + '.pxd' - extern_save_pth = os.path.join(save_dir, extern_name) - - with open(extern_save_pth, 'w') as f: - f.write(extern_code) +def _parse(header, include_dirs): + """ Parse the given header file into and ast. The include + dirs are passed along to gccxml. + + """ + # A temporary file to store the xml generated by gccxml + xml_file = tempfile.NamedTemporaryFile(suffix='.xml', delete=False) + xml_file.close() + + # buildup the gccxml command + cmds = ['gccxml'] + for inc_dir in include_dirs: + cmds.append('-I' + inc_dir) + cmds.append(header.path) + cmds.append('-fxml=%s' % xml_file.name) + + # we pipe stdout so the preprocessing doesn't dump to the + # shell. We really don't care about it. + p = subprocess.Popen(cmds, stdout=subprocess.PIPE) + cpp, _ = p.communicate() + + # Parse the xml into the ast then delete the temp file + ast = parser.parse(xml_file.name) + os.remove(xml_file.name) + + return ast + + +def _render_extern(ast, header, config): + renderer = renderers.ExternRenderer() + return renderer.render(ast, header.path, config) + + +def generate(config): + save_dir = config.save_dir + include_dirs = config.include_dirs + for header in config.headers: + items = _parse(header, include_dirs) + extern_code = _render_extern(items, header, config) + extern_path = os.path.join(save_dir, header.pxd + '.pxd') + with open(extern_path, 'wb') as f: + f.write(extern_code) + diff --git a/cwrap/config.py b/cwrap/config.py index 96c3030..ea9a311 100644 --- a/cwrap/config.py +++ b/cwrap/config.py @@ -3,14 +3,21 @@ class Header(object): - def __init__(self, path, mod_name=None): - self.path = path - self.header_name = os.path.split(path)[-1] + def __init__(self, path, pxd=None, pyx=None): + self.path = os.path.abspath(path) + self.header_name = os.path.split(self.path)[-1] + + mod_base = self.header_name.rstrip('.h') + + if pxd is None: + self.pxd = '_' + mod_base + else: + self.pxd = extern - if mod_name is None: - self.mod_name = self.header_name.rstrip('.h') + if pyx is None: + self.pyx = mod_base else: - self.mod_name = mod_name + self.pyx = pyx class Config(object): @@ -19,5 +26,16 @@ def __init__(self, include_dirs=None, save_dir=None, headers=None): self.include_dirs = include_dirs or [] self.save_dir = save_dir or os.getcwd() self.headers = headers or [] + + self._header_map = {} + for header in self.headers: + self._header_map[header.path] = header + + def header(self, header_path): + return self._header_map[header_path] + def pxd_name(self, header_path): + return self._header_map[header_path].pxd + def pyx_name(self, header_path): + return self._header_map[header_path].pyx diff --git a/cwrap/cy_ast.py b/cwrap/cy_ast.py index 0bfd89a..19b6134 100644 --- a/cwrap/cy_ast.py +++ b/cwrap/cy_ast.py @@ -1,101 +1,5 @@ -#------------------------------------------------------------------------------ -# Builtin types -#------------------------------------------------------------------------------ - -# This dict is populated by the CTypeMeta class -C_TYPES = {} - - -class CTypeMeta(type): - - def __new__(meta, cls_name, bases, cls_dict): - cls = type.__new__(meta, cls_name, bases, cls_dict) - C_TYPES[cls.c_name] = cls - return cls - - -class CType(object): - - __metaclass__ = CTypeMeta - - c_name = '' - - @classmethod - def cast(cls): - return '<%s>' % self.c_name - - @classmethod - def object_var_to_c(cls, name): - return self.cast() + name - - @classmethod - def c_var_to_object(cls, name): - return '' + name - - -class Void(CType): - c_name = 'void' - - -class Int(CType): - c_name = 'int' - - -class UInt(CType): - c_name = 'unsigned int' - - -class Short(CType): - c_name = 'short' - - -class UShort(CType): - c_name = 'unsigned short' - - -class Char(CType): - c_name = 'char' - - -class UChar(CType): - c_name = 'unsigned char' - - -class Long(CType): - c_name = 'long' - - -class ULong(CType): - c_name = 'unsigned long' - - -class LongLong(CType): - c_name = 'long long' - - -class ULongLong(CType): - c_name = 'unsigned long long' - - -class Float(CType): - c_name = 'float' - - -class Double(CType): - c_name = 'double' - - -class LongDouble(CType): - c_name = 'long double' - - -#------------------------------------------------------------------------------ -# Ast nodes -#------------------------------------------------------------------------------ - - class ASTNode(object): def __init__(self, *args, **kwargs): @@ -104,6 +8,9 @@ def __init__(self, *args, **kwargs): def init(self, *args, **kwargs): pass + + def refs(self): + return [] class Typedef(ASTNode): @@ -111,7 +18,10 @@ class Typedef(ASTNode): def init(self, name, typ): self.name = name self.typ = typ - + + def refs(self): + return [self.typ] + class FundamentalType(ASTNode): @@ -120,6 +30,9 @@ def init(self, name, size, align): self.size = size self.align = align + def refs(self): + return [] + class CvQualifiedType(ASTNode): @@ -128,6 +41,9 @@ def init(self, typ, const, volatile): self.const = const self.volatile = volatile + def refs(self): + return [self.typ] + class Ignored(ASTNode): @@ -142,6 +58,9 @@ def fixup_argtypes(self, typemap): def add_argument(self, argument): self.arguments.append(argument) + def refs(self): + return self.arguments + class Field(ASTNode): @@ -151,6 +70,9 @@ def init(self, name, typ, bits, offset): self.bits = bits self.offset = offset + def refs(self): + return [self.typ] + class Struct(ASTNode): @@ -165,6 +87,9 @@ def init(self, name, align, members, bases, size): def opaque(self): return len(self.members) == 0 + def refs(self): + return self.members + class Union(ASTNode): @@ -179,12 +104,18 @@ def init(self, name, align, members, bases, size): def opaque(self): return len(self.members) == 0 + def refs(self): + return self.members + class EnumValue(ASTNode): def init(self, name, value): self.name = name self.value = value + + def refs(self): + return [] class Enumeration(ASTNode): @@ -201,6 +132,9 @@ def add_value(self, val): @property def opaque(self): return len(self.values) == 1 + + def refs(self): + return self.values class PointerType(ASTNode): @@ -209,14 +143,20 @@ def init(self, typ, size, align): self.typ = typ self.size = size self.align = align - + + def refs(self): + return [self.typ] + class ArrayType(ASTNode): def init(self, typ, min, max): self.typ = typ - self.min = int(min.rstrip('lu')) - self.max = int(max.rstrip('lu')) + self.min = min + self.max = max + + def refs(self): + return [self.typ] class Argument(ASTNode): @@ -225,6 +165,9 @@ def init(self, typ, name): self.typ = typ self.name = name + def refs(self): + return [self.typ] + class Function(ASTNode): @@ -242,6 +185,9 @@ def fixup_argtypes(self, typemap): def add_argument(self, argument): self.arguments.append(argument) + def refs(self): + return [self.returns] + self.arguments + class FunctionType(ASTNode): @@ -256,6 +202,9 @@ def fixup_argtypes(self, typemap): def add_argument(self, argument): self.arguments.append(argument) + + def refs(self): + return [self.returns] + self.arguments class OperatorFunction(ASTNode): @@ -264,6 +213,9 @@ def init(self, name, returns): self.name = name self.returns = returns + def refs(self): + return [self.returns] + class Macro(ASTNode): @@ -279,6 +231,9 @@ def init(self, name, value, typ=None): self.name = name self.value = value self.typ = typ + + def refs(self): + return [self.typ] class File(ASTNode): @@ -293,3 +248,8 @@ def init(self, name, typ, init): self.name = name self.typ = typ self.init = init + + def refs(self): + return [self.typ] + + diff --git a/cwrap/parser.py b/cwrap/parser.py index 7f731ed..ee85e16 100644 --- a/cwrap/parser.py +++ b/cwrap/parser.py @@ -10,6 +10,9 @@ def MAKE_NAME(name): + """ Converts a mangled C++ name to a valid python identifier. + + """ name = name.replace('$', 'DOLLAR') name = name.replace('.', 'DOT') if name.startswith('__'): @@ -23,36 +26,70 @@ def MAKE_NAME(name): def CHECK_NAME(name): + """ Checks if `name` is a valid Python identifier. Returns + `name` on success, None on failure. + + """ if WORDPAT.match(name): return name return None class GCCXMLParser(object): + """ Parses a gccxml file into a list of file-level ast nodes. - has_values = set(['Enumeration', 'Function', 'FunctionType', - 'OperatorFunction', 'Method', 'Constructor', - 'Destructor', 'OperatorMethod']) + """ + # xml element types that have xml subelements. For example, + # function arguments are subelements of a function, but struct + # fields are their own toplevel xml elements + has_subelements = set(['Enumeration', 'Function', 'FunctionType', + 'OperatorFunction', 'Method', 'Constructor', + 'Destructor', 'OperatorMethod']) def __init__(self, *args): + # `context` acts like stack where parent nodes are pushed + # before visiting children self.context = [] + + # `all` maps the unique ids from the xml to the ast + # node that was generated by the element. This is used + # after all nodes have been generated to go back and + # hook up dependent nodes. self.all = {} + + # XXX - what does this do? self.cpp_data = {} + # `cdata` is used as temporary storage while elements + # are being processed. + self.cdata = None + + # `cvs_revision` stores the gccxml version in use. + self.cvs_revision = None + #-------------------------------------------------------------------------- # Parsing entry points #-------------------------------------------------------------------------- def parse(self, xmlfile): + """ Parsing entry point. `xmlfile` is a filename or a file + object. + + """ for event, node in cElementTree.iterparse(xmlfile, events=('start', 'end')): if event == 'start': - self.startElement(node.tag, dict(node.items())) + self.start_element(node.tag, dict(node.items())) else: if node.text: - self.characters(node.text) - self.endElement(node.tag) + self.visit_Characters(node.text) + self.end_element(node.tag) node.clear() - def startElement(self, name, attrs): + def start_element(self, name, attrs): + """ XML start element handler. Generates and calls the visitor + method name, registers the resulting node's id, and + sets the location on the node. + + """ # find and call the handler for this element mth = getattr(self, 'visit_' + name, None) if mth is None: @@ -62,7 +99,7 @@ def startElement(self, name, attrs): # Record the result and register the the id, which is # used in the _fixup_* methods. Some elements don't have - # and id, so we create our own. + # an id, so we create our own. if result is not None: location = attrs.get('location', None) if location is not None: @@ -73,42 +110,53 @@ def startElement(self, name, attrs): else: self.all[id(result)] = result - # if this element has children, push onto the context - if name in self.has_values: + # if this element has subelements, push it onto the context + # since the next elements will be it's children. + if name in self.has_subelements: self.context.append(result) - cdata = None - def endElement(self, name): - # if this element has children, pop the context - if name in self.has_values: + def end_element(self, name): + """ XML end element handler. + + """ + # if this element has subelements, then it will have + # been push onto the stack and needs to be removed. + if name in self.has_subelements: self.context.pop() self.cdata = None def unhandled_element(self, name, attrs): + """ Handler for element nodes where a real handler is not + found. + + """ print 'Unhandled element `%s`.' % name #-------------------------------------------------------------------------- # Ignored elements and do-nothing handlers #-------------------------------------------------------------------------- def visit_Ignored(self, attrs): + """ Ignored elements are those which we don't care about, + but need to keep in place because we care about their + children. + + """ name = attrs.get('name', None) - if not name: - name = attrs.get('mangled', 'UNDEFINED') + if name is None: + name = attrs.get('mangled', None) + if name is None: + name = 'UNDEFINED' + else: + name = MAKE_NAME(name) return cy_ast.Ignored(name) - def _fixup_Ignored(self, const): - pass - visit_Method = visit_Ignored visit_Constructor = visit_Ignored visit_Destructor = visit_Ignored visit_OperatorMethod = visit_Ignored - _fixup_Method = _fixup_Ignored - _fixup_Constructor = _fixup_Ignored - _fixup_Destructor = _fixup_Ignored - _fixup_OperatorMethod = _fixup_Ignored - + # These node types are ignored becuase we don't need anything + # at all from them. visit_Class = lambda *args: None visit_Namespace = lambda *args: None visit_Base = lambda *args: None @@ -117,58 +165,52 @@ def _fixup_Ignored(self, const): #-------------------------------------------------------------------------- # Revision Handler #-------------------------------------------------------------------------- - cvs_revision = None def visit_GCC_XML(self, attrs): + """ Handles the versioning info from the gccxml version. + + """ rev = attrs['cvs_revision'] self.cvs_revision = tuple(map(int, rev.split('.'))) - + #-------------------------------------------------------------------------- - # Real element handlers + # Text handlers #-------------------------------------------------------------------------- + def visit_Characters(self, content): + """ The character handler which is called after each xml + element has been processed. + + """ + if self.cdata is not None: + self.cdata.append(content) + def visit_CPP_DUMP(self, attrs): - name = attrs['name'] + """ Gathers preprocessor elements like macros and defines. + + """ # Insert a new list for each named section into self.cpp_data, # and point self.cdata to it. self.cdata will be set to None # again at the end of each section. + name = attrs['name'] self.cpp_data[name] = self.cdata = [] - - def characters(self, content): - if self.cdata is not None: - self.cdata.append(content) - + + #-------------------------------------------------------------------------- + # Node element handlers + #-------------------------------------------------------------------------- def visit_File(self, attrs): name = attrs['name'] - if sys.platform == 'win32' and ' ' in name: - # On windows, convert to short filename if it contains blanks - from ctypes import windll, create_unicode_buffer, sizeof, WinError - buf = create_unicode_buffer(512) - if windll.kernel32.GetShortPathNameW(name, buf, sizeof(buf)): - name = buf.value return cy_ast.File(name) - def _fixup_File(self, f): - pass - def visit_Variable(self, attrs): name = attrs['name'] - if name.startswith('cpp_sym_'): - # XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXx fix me! - name = name[len('cpp_sym_'):] init = attrs.get('init', None) typ = attrs['type'] return cy_ast.Variable(name, typ, init) - def _fixup_Variable(self, t): - t.typ = self.all[t.typ] - def visit_Typedef(self, attrs): name = attrs['name'] typ = attrs['type'] return cy_ast.Typedef(name, typ) - - def _fixup_Typedef(self, t): - t.typ = self.all[t.typ] - + def visit_FundamentalType(self, attrs): name = attrs['name'] if name == 'void': @@ -178,84 +220,48 @@ def visit_FundamentalType(self, attrs): align = attrs['align'] return cy_ast.FundamentalType(name, size, align) - def _fixup_FundamentalType(self, t): - pass - def visit_PointerType(self, attrs): typ = attrs['type'] size = attrs['size'] align = attrs['align'] return cy_ast.PointerType(typ, size, align) - def _fixup_PointerType(self, p): - p.typ = self.all[p.typ] - visit_ReferenceType = visit_PointerType - _fixup_ReferenceType = _fixup_PointerType - + def visit_ArrayType(self, attrs): - # type, min?, max? + # min, max are the min and max array indices typ = attrs['type'] min = attrs['min'] max = attrs['max'] if max == 'ffffffffffffffff': max = '-1' + min = int(min.rstrip('lu')) + max = int(max.rstrip('lu')) return cy_ast.ArrayType(typ, min, max) - def _fixup_ArrayType(self, a): - a.typ = self.all[a.typ] - def visit_CvQualifiedType(self, attrs): - # id, type, [const|volatile] typ = attrs['type'] const = attrs.get('const', None) volatile = attrs.get('volatile', None) return cy_ast.CvQualifiedType(typ, const, volatile) - - def _fixup_CvQualifiedType(self, c): - c.typ = self.all[c.typ] - + def visit_Function(self, attrs): - # name, returns, extern, attributes name = attrs['name'] returns = attrs['returns'] attributes = attrs.get('attributes', '').split() extern = attrs.get('extern') return cy_ast.Function(name, returns, attributes, extern) - def _fixup_Function(self, func): - func.returns = self.all[func.returns] - func.fixup_argtypes(self.all) - def visit_FunctionType(self, attrs): - # id, returns, attributes returns = attrs['returns'] attributes = attrs.get('attributes', '').split() return cy_ast.FunctionType(returns, attributes) - - def _fixup_FunctionType(self, func): - func.returns = self.all[func.returns] - func.fixup_argtypes(self.all) - + def visit_OperatorFunction(self, attrs): - # name, returns, extern, attributes name = attrs['name'] returns = attrs['returns'] return cy_ast.OperatorFunction(name, returns) - def _fixup_OperatorFunction(self, func): - func.returns = self.all[func.returns] - - #def Method(self, attrs): - # # name, virtual, pure_virtual, returns - # name = attrs['name'] - # returns = attrs['returns'] - # return typedesc.Method(name, returns) - - #def _fixup_Method(self, m): - # m.returns = self.all[m.returns] - # m.fixup_argtypes(self.all) - def visit_Argument(self, attrs): parent = self.context[-1] if parent is not None: @@ -265,17 +271,13 @@ def visit_Argument(self, attrs): parent.add_argument(arg) def visit_Enumeration(self, attrs): - # id, name - name = attrs['name'] - # If the name isn't a valid Python identifier, create an unnamed enum - name = CHECK_NAME(name) + # If the name isn't a valid Python identifier, + # create an unnamed enum + name = CHECK_NAME(attrs['name']) size = attrs['size'] align = attrs['align'] return cy_ast.Enumeration(name, size, align) - - def _fixup_Enumeration(self, e): - pass - + def visit_EnumValue(self, attrs): parent = self.context[-1] if parent is not None: @@ -284,11 +286,7 @@ def visit_EnumValue(self, attrs): val = cy_ast.EnumValue(name, value) parent.add_value(val) - def _fixup_EnumValue(self, e): - pass - def visit_Struct(self, attrs): - # id, name, members name = attrs.get('name') if name is None: name = MAKE_NAME(attrs['mangled']) @@ -297,11 +295,7 @@ def visit_Struct(self, attrs): align = attrs['align'] size = attrs.get('size') return cy_ast.Struct(name, align, members, bases, size) - - def _fixup_Struct(self, s): - s.members = [self.all[m] for m in s.members] - s.bases = [self.all[b] for b in s.bases] - + def visit_Union(self, attrs): name = attrs.get('name') if name is None: @@ -312,33 +306,97 @@ def visit_Union(self, attrs): size = attrs.get('size') return cy_ast.Union(name, align, members, bases, size) - def _fixup_Union(self, u): - u.members = [self.all[m] for m in u.members] - u.bases = [self.all[b] for b in u.bases] - def visit_Field(self, attrs): - # name, type name = attrs['name'] typ = attrs['type'] bits = attrs.get('bits', None) offset = attrs.get('offset') return cy_ast.Field(name, typ, bits, offset) + #-------------------------------------------------------------------------- + # Fixup handlers + #-------------------------------------------------------------------------- + + # The fixup handlers use the ids save on the node attrs to lookup + # the replacement node from the storage, then do the swapout. There + # must be a fixup handler (even if its pass-thru) for each node + # handler that returns a node object. + + def _fixup_File(self, f): + pass + + def _fixup_Variable(self, t): + t.typ = self.all[t.typ] + + def _fixup_Typedef(self, t): + t.typ = self.all[t.typ] + + def _fixup_FundamentalType(self, t): + pass + + def _fixup_PointerType(self, p): + p.typ = self.all[p.typ] + + _fixup_ReferenceType = _fixup_PointerType + + def _fixup_ArrayType(self, a): + a.typ = self.all[a.typ] + + def _fixup_CvQualifiedType(self, c): + c.typ = self.all[c.typ] + + def _fixup_Function(self, func): + func.returns = self.all[func.returns] + func.fixup_argtypes(self.all) + + def _fixup_FunctionType(self, func): + func.returns = self.all[func.returns] + func.fixup_argtypes(self.all) + + def _fixup_OperatorFunction(self, func): + func.returns = self.all[func.returns] + + def _fixup_Enumeration(self, e): + pass + + def _fixup_EnumValue(self, e): + pass + + def _fixup_Struct(self, s): + s.members = [self.all[m] for m in s.members] + s.bases = [self.all[b] for b in s.bases] + + def _fixup_Union(self, u): + u.members = [self.all[m] for m in u.members] + u.bases = [self.all[b] for b in u.bases] + def _fixup_Field(self, f): f.typ = self.all[f.typ] def _fixup_Macro(self, m): pass + def _fixup_Ignored(self, const): + pass + + _fixup_Method = _fixup_Ignored + _fixup_Constructor = _fixup_Ignored + _fixup_Destructor = _fixup_Ignored + _fixup_OperatorMethod = _fixup_Ignored + #-------------------------------------------------------------------------- # Post parsing helpers #-------------------------------------------------------------------------- def get_macros(self, text): + """ Attempts to extract the macros from a piece of text + and converts it to a Macro node containing the name, + args, and body. + + """ if text is None: return - - # preprocessor definitions that look like macros with one - # or more arguments + + # join and split so we can accept a list or string. text = ''.join(text) for m in text.splitlines(): name, body = m.split(None, 1) @@ -347,13 +405,15 @@ def get_macros(self, text): self.all[name] = cy_ast.Macro(name, args, body) def get_aliases(self, text, namespace): + """ Attemps to extract defined aliases of the form + #define A B and store them in an Alias node. + + """ if text is None: return - # preprocessor definitions that look like aliases: - # #define A B - text = ''.join(text) aliases = {} + text = ''.join(text) for a in text.splitlines(): name, value = a.split(None, 1) a = cy_ast.Alias(name, value) @@ -372,6 +432,12 @@ def get_aliases(self, text, namespace): pass def get_result(self): + """ After parsing, call this method to retrieve the results + as a list of AST nodes. This list will contain *all* nodes + in the xml file which will include a bunch of builtin and + internal stuff that you wont want. + + """ # Drop some warnings for early gccxml versions import warnings if self.cvs_revision is None: @@ -382,7 +448,7 @@ def get_result(self): # Gather any macros. self.get_macros(self.cpp_data.get('functions')) - # Pass through all the items, hooking up the appropriate + # Walk through all the items, hooking up the appropriate # links by replacing the id tags with the actual objects remove = [] for name, node in self.all.items(): @@ -401,27 +467,24 @@ def get_result(self): for n in remove: del self.all[n] - # Now we can build the namespace, keeping only the nodes + # Now we can build the namespace composed only of the nodes # in which we're interested. interesting = (cy_ast.Typedef, cy_ast.Enumeration, cy_ast.EnumValue, cy_ast.Function, cy_ast.Struct, cy_ast.Union, cy_ast.Variable, cy_ast.Macro, cy_ast.Alias) + result = [] namespace = {} for node in self.all.values(): if not isinstance(node, interesting): - continue # we don't want these + continue + result.append(node) name = getattr(node, 'name', None) if name is not None: namespace[name] = node self.get_aliases(self.cpp_data.get('aliases'), namespace) - result = [] - for node in self.all.values(): - if isinstance(node, interesting): - result.append(node) - return result @@ -430,12 +493,4 @@ def parse(xmlfile): parser = GCCXMLParser() parser.parse(xmlfile) items = parser.get_result() - in_name = os.path.split(xmlfile)[-1].replace('xml', 'h') - res = [] - for item in items: - if item.location: - out_name = os.path.split(item.location[0])[-1] - if out_name == in_name: - res.append(item) - res.sort(key=lambda item: int(item.location[1])) - return res + return items diff --git a/cwrap/renderers.py b/cwrap/renderers.py index 90834e4..324d3b3 100644 --- a/cwrap/renderers.py +++ b/cwrap/renderers.py @@ -64,9 +64,6 @@ def _gen_imports(self): import_lines.append('cimport %s as %s' % (module, name)) else: import_lines.append('cimport %s' % module) - - if import_lines: - import_lines.append('\n') # cimports from cimport_from_items = sorted( self._cimports_from.iteritems() ) @@ -81,9 +78,6 @@ def _gen_imports(self): sub_txt = ', '.join(sub_lines) import_lines.append('from %s cimport %s' % (module, sub_txt)) - if import_lines: - import_lines.append('\n') - # cimports import_items = sorted( self._imports.iteritems() ) for module, as_names in import_items: @@ -93,9 +87,6 @@ def _gen_imports(self): else: import_lines.append('import %s' % module) - if import_lines: - import_lines.append('\n') - # cimports from import_from_items = sorted( self._imports_from.iteritems() ) for module, impl_dct in import_from_items: @@ -125,19 +116,45 @@ class ExternRenderer(object): def __init__(self): self.context = [None] self.code = None + self.header_path = None + self.config = None - def render(self, items, header_name): + def render(self, items, header_path, config): self.context = [None] self.code = Code() + self.header_path = header_path + self.config = config + + # filter for the items that are toplevel in this header + # and order them by their appearance + toplevel = filter(lambda item: item.location[0] == header_path, items) + toplevel.sort(key=lambda item: int(item.location[1])) + # work out any imports we need + #for item in toplevel: + # self.resolve_imports(item, set()) + + header_name = self.config.header(header_path).header_name self.code.write_i('cdef extern from "%s":\n\n' % header_name) self.code.indent() - for item in items: + for item in toplevel: self.visit(item) self.code.dedent() return self.code.code() - + + def resolve_imports(self, node, visited): + if isinstance(node, cy_ast.ASTNode): + if not isinstance(node, cy_ast.FundamentalType): + if node.location: + if node.location[0] != self.header_path: + pxd = self.config.pxd_name(node.location[0]) + self.code.add_cimport_from(pxd, '*') + visited.add(node) + for snode in node.refs(): + if snode not in visited: + self.resolve_imports(snode, visited) + def visit(self, node): self.context.append(node) method = 'visit_' + node.__class__.__name__ @@ -258,16 +275,19 @@ def visit_Function(self, function): def visit_Argument(self, argument): name = argument.name typ = argument.typ + + if isinstance(typ, cy_ast.CvQualifiedType): + typ = typ.typ + if isinstance(typ, (cy_ast.Typedef, cy_ast.FundamentalType, cy_ast.Enumeration, cy_ast.Struct)): typ_name = typ.name elif isinstance(typ, (cy_ast.PointerType, cy_ast.ArrayType)): typ_name, name = self.apply_modifier(typ, name) - elif isinstance(typ, cy_ast.CvQualifiedType): - typ_name = typ.name else: print 'unhandled argument type node: `%s`' % typ typ_name = UNDEFINED + if name is not None: self.code.write('%s %s' % (typ_name, name)) else: @@ -312,6 +332,9 @@ def apply_modifier(self, node, name): stack = [] typ = node + if name is None: + name = '' + while isinstance(typ, (cy_ast.PointerType, cy_ast.ArrayType, cy_ast.CvQualifiedType)): if isinstance(typ, (cy_ast.PointerType, cy_ast.ArrayType)): @@ -331,7 +354,10 @@ def apply_modifier(self, node, name): name = '(' + name + ')' if isinstance(node, cy_ast.PointerType): - name = '*' + name + try: + name = '*' + name + except: + import pdb; pdb.set_trace() elif isinstance(node, cy_ast.ArrayType): max = node.max if max is None: