Skip to content

Commit

Permalink
- repair issue in sax.ElementTreeContentHandler
Browse files Browse the repository at this point in the history
whereby attributes passed to startElement() would be mis-interpreted
as containing a namespace attribute, leading to a TypeError,
as well as where attributes with namespaces wouldn't be split
up correctly when passed to startElement().
  • Loading branch information
zzzeek committed Mar 29, 2013
1 parent d829718 commit 3222e75
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
14 changes: 12 additions & 2 deletions src/lxml/sax.py
Expand Up @@ -26,6 +26,12 @@ def _getNsTag(tag):
else: else:
return (None, tag) return (None, tag)


def _getAttrNSTag(attr_key):
if ":" in attr_key:
return tuple(attr_key.split(":", 1))
else:
return (None, attr_key)

class ElementTreeContentHandler(ContentHandler): class ElementTreeContentHandler(ContentHandler):
"""Build an lxml ElementTree from SAX events. """Build an lxml ElementTree from SAX events.
""" """
Expand All @@ -45,7 +51,7 @@ def _get_etree(self):
return ElementTree(self._root) return ElementTree(self._root)


etree = property(_get_etree, doc=_get_etree.__doc__) etree = property(_get_etree, doc=_get_etree.__doc__)

def setDocumentLocator(self, locator): def setDocumentLocator(self, locator):
pass pass


Expand Down Expand Up @@ -127,6 +133,10 @@ def endElementNS(self, ns_name, qname):
raise SaxError("Unexpected element closed: " + el_tag) raise SaxError("Unexpected element closed: " + el_tag)


def startElement(self, name, attributes=None): def startElement(self, name, attributes=None):
if attributes is not None:
attributes = dict(
[(_getAttrNSTag(k), v) for k, v in attributes.items()]
)
self.startElementNS((None, name), name, attributes) self.startElementNS((None, name), name, attributes)


def endElement(self, name): def endElement(self, name):
Expand All @@ -143,7 +153,7 @@ def characters(self, data):
last_element.text = (last_element.text or '') + data last_element.text = (last_element.text or '') + data


ignorableWhitespace = characters ignorableWhitespace = characters



class ElementTreeProducer(object): class ElementTreeProducer(object):
"""Produces SAX events for an element and children. """Produces SAX events for an element and children.
Expand Down
38 changes: 36 additions & 2 deletions src/lxml/tests/test_sax.py
Expand Up @@ -216,6 +216,40 @@ def test_etree_sax_no_ns(self):
self.assertEqual('b', root[0].tag) self.assertEqual('b', root[0].tag)
self.assertEqual('c', root[1].tag) self.assertEqual('c', root[1].tag)


def test_etree_sax_no_ns_attributes(self):
handler = sax.ElementTreeContentHandler()
handler.startDocument()
handler.startElement('a', {"attr_a1": "a1"})
handler.startElement('b', {"attr_b1": "b1"})
handler.endElement('b')
handler.endElement('a')
handler.endDocument()

new_tree = handler.etree
root = new_tree.getroot()
self.assertEqual('a', root.tag)
self.assertEqual('b', root[0].tag)
self.assertEqual('a1', root.attrib["attr_a1"])
self.assertEqual('b1', root[0].attrib["attr_b1"])

def test_etree_sax_ns_attributes(self):
handler = sax.ElementTreeContentHandler()
handler.startDocument()

handler.startElement('a', {"blaA:attr_a1": "a1"})
handler.startElement('b', {"blaA:attr_b1": "b1"})
handler.endElement('b')
handler.endElement('a')

handler.endDocument()

new_tree = handler.etree
root = new_tree.getroot()
self.assertEqual('a', root.tag)
self.assertEqual('b', root[0].tag)
self.assertEqual('a1', root.attrib["{blaA}attr_a1"])
self.assertEqual('b1', root[0].attrib["{blaA}attr_b1"])

def test_etree_sax_error(self): def test_etree_sax_error(self):
handler = sax.ElementTreeContentHandler() handler = sax.ElementTreeContentHandler()
handler.startDocument() handler.startDocument()
Expand All @@ -233,14 +267,14 @@ def _saxify_unsaxify(self, saxifiable):
handler = sax.ElementTreeContentHandler() handler = sax.ElementTreeContentHandler()
sax.ElementTreeProducer(saxifiable, handler).saxify() sax.ElementTreeProducer(saxifiable, handler).saxify()
return handler.etree return handler.etree

def _saxify_serialize(self, tree): def _saxify_serialize(self, tree):
new_tree = self._saxify_unsaxify(tree) new_tree = self._saxify_unsaxify(tree)
f = BytesIO() f = BytesIO()
new_tree.write(f) new_tree.write(f)
return f.getvalue().replace(_bytes('\n'), _bytes('')) return f.getvalue().replace(_bytes('\n'), _bytes(''))



def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTests([unittest.makeSuite(ETreeSaxTestCase)]) suite.addTests([unittest.makeSuite(ETreeSaxTestCase)])
Expand Down

0 comments on commit 3222e75

Please sign in to comment.