Skip to content

Commit

Permalink
core[patch]: Use defusedxml in XMLOutputParser (#19526)
Browse files Browse the repository at this point in the history
This mitigates a security concern for users still on older versions of libexpat, where an attacker who manages to surface a malicious payload to this XMLParser could compromise the availability of the system.
  • Loading branch information
eyurtsev committed Mar 25, 2024
1 parent e1a6341 commit 727d502
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 36 deletions.
50 changes: 42 additions & 8 deletions libs/core/langchain_core/output_parsers/xml.py
@@ -1,6 +1,7 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from xml.etree import ElementTree as ET
from xml.etree.ElementTree import TreeBuilder

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
Expand Down Expand Up @@ -35,6 +36,10 @@ def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

def parse(self, text: str) -> Dict[str, List[Any]]:
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml import ElementTree as DET # type: ignore[import]

# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
Expand All @@ -46,18 +51,24 @@ def parse(self, text: str) -> Dict[str, List[Any]]:

text = text.strip()
try:
root = ET.fromstring(text)
root = DET.fromstring(text)
return self._root_to_dict(root)

except ET.ParseError as e:
except (DET.ParseError, DET.EntitiesForbidden) as e:
msg = f"Failed to parse XML format from completion {text}. Got: {e}"
raise OutputParserException(msg, llm_output=text) from e

def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]

parser = ET.XMLPullParser(
["start", "end"], _parser=DefusedXMLParser(target=TreeBuilder())
)
xml_start_re = re.compile(r"<[a-zA-Z:_]")
parser = ET.XMLPullParser(["start", "end"])
xml_started = False
current_path: List[str] = []
current_path_has_children = False
Expand All @@ -83,6 +94,7 @@ def _transform(
parser.feed(buffer)
buffer = ""
# yield all events

for event, elem in parser.read_events():
if event == "start":
# update current path
Expand All @@ -105,18 +117,37 @@ def _transform(
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
# Imports are temporarily placed here to avoid issue with caching on CI
# likely if you're reading this you can move them to the top of the file
from defusedxml.ElementTree import DefusedXMLParser # type: ignore[import]

_parser = DefusedXMLParser(target=TreeBuilder())
parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
xml_start_re = re.compile(r"<[a-zA-Z:_]")
xml_started = False
current_path: List[str] = []
current_path_has_children = False
buffer = ""
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# add chunk to buffer of unprocessed text
buffer += chunk
# if xml string hasn't started yet, continue to next chunk
if not xml_started:
if match := xml_start_re.search(buffer):
# if xml string has started, remove all text before it
buffer = buffer[match.start() :]
xml_started = True
else:
continue
# feed buffer to parser
parser.feed(buffer)
buffer = ""
# yield all events
for event, elem in parser.read_events():
if event == "start":
Expand All @@ -130,7 +161,10 @@ async def _atransform(
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
if current_path:
current_path_has_children = True
else:
xml_started = False
# close parser
parser.close()

Expand Down
48 changes: 24 additions & 24 deletions libs/core/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/core/pyproject.toml
Expand Up @@ -19,6 +19,7 @@ PyYAML = ">=5.3"
requests = "^2"
packaging = "^23.2"
jinja2 = { version = "^3", optional = true }
defusedxml = "^0.7"

[tool.poetry.group.lint]
optional = true
Expand Down
59 changes: 55 additions & 4 deletions libs/core/tests/unit_tests/output_parsers/test_xml_parser.py
@@ -1,4 +1,7 @@
"""Test XMLOutputParser"""
from typing import AsyncIterator
from xml.etree.ElementTree import ParseError

import pytest

from langchain_core.exceptions import OutputParserException
Expand Down Expand Up @@ -40,19 +43,29 @@
""",
],
)
def test_xml_output_parser(result: str) -> None:
async def test_xml_output_parser(result: str) -> None:
"""Test XMLOutputParser."""

xml_parser = XMLOutputParser()

xml_result = xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == xml_result
assert DEF_RESULT_EXPECTED == xml_parser.parse(result)
assert DEF_RESULT_EXPECTED == (await xml_parser.aparse(result))
assert list(xml_parser.transform(iter(result))) == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]

async def _as_iter(string: str) -> AsyncIterator[str]:
for c in string:
yield c

chunks = [chunk async for chunk in xml_parser.atransform(_as_iter(result))]
assert chunks == [
{"foo": [{"bar": [{"baz": None}]}]},
{"foo": [{"bar": [{"baz": "slim.shady"}]}]},
{"foo": [{"baz": "tag"}]},
]


@pytest.mark.parametrize("result", ["foo></foo>", "<foo></foo", "foo></foo", "foofoo"])
def test_xml_output_parser_fail(result: str) -> None:
Expand All @@ -63,3 +76,41 @@ def test_xml_output_parser_fail(result: str) -> None:
with pytest.raises(OutputParserException) as e:
xml_parser.parse(result)
assert "Failed to parse" in str(e)


# Payload for the "billion laughs" test below: the DTD declares entities
# lol1..lol9, each of which references the previous entity ten times, and the
# document body references the outermost one (&lol9;).
MALICIOUS_XML = """<?xml version="1.0"?>
<!DOCTYPE lolz [<!ENTITY lol "lol"><!ELEMENT lolz (#PCDATA)>
<!ENTITY lol1 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol2 "&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;&lol1;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
<!ENTITY lol5 "&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;&lol4;">
<!ENTITY lol6 "&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;&lol5;">
<!ENTITY lol7 "&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;&lol6;">
<!ENTITY lol8 "&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;&lol7;">
<!ENTITY lol9 "&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;&lol8;">
]>
<lolz>&lol9;</lolz>"""


async def tests_billion_laughs_attack() -> None:
    """Check that MALICIOUS_XML is rejected on every parsing entry point.

    ``parse`` and ``aparse`` must raise ``OutputParserException``; the
    streaming ``transform`` and ``atransform`` paths must raise ``ParseError``.
    """
    parser = XMLOutputParser()

    with pytest.raises(OutputParserException):
        parser.parse(MALICIOUS_XML)

    with pytest.raises(OutputParserException):
        await parser.aparse(MALICIOUS_XML)

    with pytest.raises(ParseError):
        # Right now raises undefined entity error.
        # The stream is only drained; there is no expected output to compare
        # against because the call must fail before yielding a full result.
        list(parser.transform(iter(MALICIOUS_XML)))

    async def _as_iter(string: str) -> AsyncIterator[str]:
        # Feed the payload one character at a time, like a token stream.
        for c in string:
            yield c

    with pytest.raises(ParseError):
        async for _ in parser.atransform(_as_iter(MALICIOUS_XML)):
            pass

1 comment on commit 727d502

@fubuki8087
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still get a problem when I run the PoC:

poc

Besides, the unit test tests_billion_laughs_attack fails too:

unit_test

It seems something is wrong with this patch and it does not work as intended.

Please sign in to comment.