From ab7d1caab44e7112c7323e0de478b761b73f848b Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Fri, 7 Feb 2020 16:53:09 +0100 Subject: [PATCH 01/11] Add helper class for converting pieces of strings, list of XML elements in encoded strings representing an xml piece, to be sent to the client. --- ospd/xml.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/ospd/xml.py b/ospd/xml.py index 62cedbed..fc50db69 100644 --- a/ospd/xml.py +++ b/ospd/xml.py @@ -133,3 +133,82 @@ def elements_as_text( text = ''.join([text, ele_txt]) return text + + +class XmlStringHelper: + """ Class with methods to help the creation of a xml object in + string format. + """ + + def create_element(self, elem_name: str, end: bool = False) -> bytes: + """ Get a name and create the open element of an entity. + + Arguments: + elem_name (str): The name of the tag element. + end (bool): Create a initial tag if False, otherwise the end tag. + + Return: + Encoded string representing a part of an xml element. + """ + if end: + ret = "" % elem_name + else: + ret = "<%s>" % elem_name + + return ret.encode() + + def create_response(self, command: str, end: bool = False) -> bytes: + """ Create or end an xml response. + + Arguments: + command (str): The name of the command for the response element. + end (bool): Create a initial tag if False, otherwise the end tag. + + Return: + Encoded string representing a part of an xml element. + """ + if not command: + return + + if end: + return ('' % command).encode() + + return ( + '<%s_response status="200" status_text="OK">' % command + ).encode() + + def add_element( + self, + content: Union[Element, str, list], + xml_str: bytes = None, + end: bool = False, + ) -> bytes: + """Create the initial or ending tag for a subelement, or add + one or many xml elements + + Arguments: + content (Element, str, list): Content to add. + xml_str (bytes): Initial string where content to be added to. + end (bool): Create a initial tag if False, otherwise the end tag. + It will be added to the xml_str. + + Return: + Encoded string representing a part of an xml element. + """ + + if not xml_str: + xml_str = b'' + + if content: + if isinstance(content, list): + for elem in content: + xml_str = xml_str + tostring(elem) + elif isinstance(content, Element): + xml_str = xml_str + tostring(content) + else: + if end: + xml_str = xml_str + self.create_element(content, False) + else: + xml_str = xml_str + self.create_element(content) + + return xml_str From 0e86fc17d8d4d84aaee62d487966f6e35fd966ed Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 08:59:46 +0100 Subject: [PATCH 02/11] Send response piece by piece. Before the response was a single huge xml element in a string, made before sending which caused a huge memory consumption. Now the amount of memory used for this cmd is quite smaller. --- ospd/command/command.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ospd/command/command.py b/ospd/command/command.py index fe92e2cd..088223a7 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -27,7 +27,11 @@ from ospd.misc import valid_uuid, create_process from ospd.network import target_str_to_list from ospd.protocol import OspRequest, OspResponse -from ospd.xml import simple_response_str, get_elements_from_dict +from ospd.xml import ( + simple_response_str, + get_elements_from_dict, + XmlStringHelper, +) from .initsubclass import InitSubclassMeta from .registry import register_command @@ -285,23 +289,25 @@ class GetVts(BaseCommand): 'filter': 'Optional filter to get an specific vt collection.', } - def handle_xml(self, xml: Element) -> str: + def handle_xml(self, xml: Element, stream) -> str: """ Handles command. + Writes the vt collection on the stream. The element accept two optional arguments. vt_id argument receives a single vt id. filter argument receives a filter selecting a sub set of vts. If both arguments are given, the vts which match with the filter are return. - @return: Response string for command. + @return: Response string for command on fail. """ + xml_helper = XmlStringHelper() vt_id = xml.get('vt_id') vt_filter = xml.get('filter') if vt_id and vt_id not in self._daemon.vts: text = "Failed to find vulnerability test '{0}'".format(vt_id) - return simple_response_str('get_vts', 404, text) + return simple_response_str('get_vts', 404, 'VT Not Found', text) filtered_vts = None if vt_filter: @@ -309,13 +315,14 @@ def handle_xml(self, xml: Element) -> str: self._daemon.vts, vt_filter ) - responses = [] - - vts_xml = self._daemon.get_vts_xml(vt_id, filtered_vts) + stream.write(xml_helper.create_response('get_vts')) + stream.write(xml_helper.create_element('vts')) - responses.append(vts_xml) + for vts_chunk in self._daemon.get_vts_xml(vt_id, filtered_vts): + stream.write(xml_helper.add_element(vts_chunk)) - return simple_response_str('get_vts', 200, 'OK', responses) + stream.write(xml_helper.create_element('vts', end=True)) + stream.write(xml_helper.create_response('get_vts', end=True)) class StopScan(BaseCommand): From b9136af33cea71e53dd0c81cb5607231f3554536 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 09:04:59 +0100 Subject: [PATCH 03/11] Pass the stream object to handle_command() --- ospd/ospd.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index 0364b068..bc93c2b6 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -496,7 +496,7 @@ def handle_client_stream(self, stream) -> None: return try: - response = self.handle_command(data) + response = self.handle_command(data, stream) except OspdCommandError as exception: response = exception.as_xml() logger.debug('Command error: %s', exception.message) @@ -505,7 +505,8 @@ def handle_client_stream(self, stream) -> None: exception = OspdCommandError('Fatal error', 'error') response = exception.as_xml() - stream.write(response) + if response: + stream.write(response) stream.close() def parallel_scan(self, scan_id: str, target: str) -> None: @@ -1180,7 +1181,7 @@ def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): return vts_xml - def handle_command(self, command: str) -> str: + def handle_command(self, command: str, stream) -> str: """ Handles an osp command in a string. @return: OSP Response to command. @@ -1195,6 +1196,9 @@ def handle_command(self, command: str) -> str: if not command and tree.tag != "authenticate": raise OspdCommandError('Bogus command name') + if tree.tag == "get_vts": + return command.handle_xml(tree, stream) + return command.handle_xml(tree) def check(self): From 7c975c13a505a5b1b5ee6b9a6448fe637e58d26f Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 09:06:45 +0100 Subject: [PATCH 04/11] Get a oid list and create a generator to return the vts. Using this generartor reduce the amount of memory used for the get_vts response --- ospd/ospd.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ospd/ospd.py b/ospd/ospd.py index bc93c2b6..65431f08 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1142,7 +1142,8 @@ def get_vt_xml(self, vt_id: str): return vt_xml def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): - """ Gets collection of vulnerability test information in XML format. + """ Python Generator for VTS. + Gets collection of vulnerability test information in XML format. If vt_id is specified, the collection will contain only this vt, if found. If no vt_id is specified or filtered_vts is None (default), the @@ -1152,34 +1153,32 @@ def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): Arguments: vt_id (vt_id, optional): ID of the vt to get. - filtered_vts (dict, optional): Filtered VTs collection. + filtered_vts (list, optional): Filtered VTs collection. Return: String of collection of vulnerability test information in XML format. """ - - vts_xml = Element('vts') - + vts_xml = [] if not self.vts: return vts_xml + # No match for the filter if filtered_vts is not None and len(filtered_vts) == 0: return vts_xml if filtered_vts: - for vt_id in filtered_vts: - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = filtered_vts elif vt_id: - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = [vt_id] else: - # Because DictProxy for python3.5 doesn't support + # TODO: Because DictProxy for python3.5 doesn't support # iterkeys(), itervalues(), or iteritems() either, the iteration # must be done as follow. - for vt_id in iter(self.vts.keys()): - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = iter(self.vts.keys()) - return vts_xml + for vt_id in vts_list: + yield self.get_vt_xml(vt_id) def handle_command(self, command: str, stream) -> str: """ Handles an osp command in a string. From 8bdaec9b46af0b14d87f605ddb4dddc575058410 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 09:10:18 +0100 Subject: [PATCH 05/11] Improve get_filtered_vts_list(). Don't do a copy. Instead use a generator. This reduce considerably the amount of memory used during the filtered vts list. --- ospd/vtfilter.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/ospd/vtfilter.py b/ospd/vtfilter.py index bfe4f8a2..756bd058 100644 --- a/ospd/vtfilter.py +++ b/ospd/vtfilter.py @@ -98,7 +98,9 @@ def format_filter_value(self, element: str, value: Dict): format_func = self.allowed_filter.get(element) return format_func(value) - def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: + def get_filtered_vts_list( + self, vts: Dict, vt_filter: str + ) -> Optional[Dict]: """ Gets a collection of vulnerability test from the vts dictionary, which match the filter. @@ -107,7 +109,8 @@ def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: vts (dictionary): The complete vts collection. Returns: - Dictionary with filtered vulnerability tests. + List with filtered vulnerability tests. The list can be empty. + None in case of filter parse failure. """ if not vt_filter: raise OspdCommandError('vt_filter: A valid filter is required.') @@ -116,17 +119,21 @@ def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: if not filters: return None - _vts_aux = vts.copy() + vt_oid_list = list(vts.keys()) + for _element, _oper, _filter_val in filters: - for vt_id in _vts_aux.copy(): - if not _vts_aux[vt_id].get(_element): - _vts_aux.pop(vt_id) + vts_generator = (vt for vt in vts) + for vt_oid in vts_generator: + if vt_oid not in vt_oid_list: + continue + if not vts[vt_oid].get(_element): + vt_oid_list.remove(vt_oid) continue - _elem_val = _vts_aux[vt_id].get(_element) + _elem_val = vts[vt_oid].get(_element) _val = self.format_filter_value(_element, _elem_val) if self.filter_operator[_oper](_val, _filter_val): continue else: - _vts_aux.pop(vt_id) + vt_oid_list.remove(vt_oid) - return _vts_aux + return vt_oid_list From 1f3d9f01b1376507a55778da39de79b9cbe6875d Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 09:20:22 +0100 Subject: [PATCH 06/11] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98109267..091527fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Modify __init__() method and use new syntax for super(). [#186](https://github.com/greenbone/ospd/pull/186) - Create data manager and spawn new process to keep the vts dictionary. [#191](https://github.com/greenbone/ospd/pull/191) - Update daemon start sequence. Run daemon.check before daemon.init now. [#197](https://github.com/greenbone/ospd/pull/197) +- Improve get_vts cmd response, sending the vts piece by piece.[#201](https://github.com/greenbone/ospd/pull/201) ## [2.0.1] (unreleased) From 4ddae3fa8733a7e863cf6c146f477299bd9fbeeb Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 11:38:09 +0100 Subject: [PATCH 07/11] Import command to initialize the metaclass --- ospd/command/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ospd/command/__init__.py b/ospd/command/__init__.py index 26114314..bae86a6a 100644 --- a/ospd/command/__init__.py +++ b/ospd/command/__init__.py @@ -16,4 +16,5 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. +import ospd.command.command from .registry import get_commands From c2f9c4a8bb6dad08c3f972f1c4efc684924999a6 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 13:36:08 +0100 Subject: [PATCH 08/11] Don't pass stream to handle_xml(). --- ospd/command/command.py | 28 ++++++++++++++++++++-------- ospd/ospd.py | 15 +++++++++------ 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/ospd/command/command.py b/ospd/command/command.py index 088223a7..b0016346 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -19,6 +19,7 @@ import re import subprocess +from types import GeneratorType from typing import Optional, Dict, Any from xml.etree.ElementTree import Element, SubElement @@ -289,7 +290,7 @@ class GetVts(BaseCommand): 'filter': 'Optional filter to get an specific vt collection.', } - def handle_xml(self, xml: Element, stream) -> str: + def handle_xml(self, xml: Element) -> str: """ Handles command. Writes the vt collection on the stream. The element accept two optional arguments. @@ -307,7 +308,6 @@ def handle_xml(self, xml: Element, stream) -> str: if vt_id and vt_id not in self._daemon.vts: text = "Failed to find vulnerability test '{0}'".format(vt_id) - return simple_response_str('get_vts', 404, 'VT Not Found', text) filtered_vts = None if vt_filter: @@ -315,14 +315,26 @@ def handle_xml(self, xml: Element, stream) -> str: self._daemon.vts, vt_filter ) - stream.write(xml_helper.create_response('get_vts')) - stream.write(xml_helper.create_element('vts')) + # Generator + vts_list = (vt for vt in self._daemon.get_vts_xml(vt_id, filtered_vts)) - for vts_chunk in self._daemon.get_vts_xml(vt_id, filtered_vts): - stream.write(xml_helper.add_element(vts_chunk)) + # List of xml pieces with the generator to be iterated + response = [ + xml_helper.create_response('get_vts'), + xml_helper.create_element('vts'), + vts_list, + xml_helper.create_element('vts', end=True), + xml_helper.create_response('get_vts', end=True), + ] - stream.write(xml_helper.create_element('vts', end=True)) - stream.write(xml_helper.create_response('get_vts', end=True)) + for elem in response: + if isinstance(elem, GeneratorType): + for vts_chunk in elem: + yield xml_helper.add_element( + self._daemon.get_vt_xml(vts_chunk) + ) + else: + yield elem class StopScan(BaseCommand): diff --git a/ospd/ospd.py b/ospd/ospd.py index 65431f08..3ca40e07 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -495,8 +495,9 @@ def handle_client_stream(self, stream) -> None: logger.debug("Empty client stream") return + response = None try: - response = self.handle_command(data, stream) + self.handle_command(data, stream) except OspdCommandError as exception: response = exception.as_xml() logger.debug('Command error: %s', exception.message) @@ -1177,8 +1178,7 @@ def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): # must be done as follow. vts_list = iter(self.vts.keys()) - for vt_id in vts_list: - yield self.get_vt_xml(vt_id) + return vts_list def handle_command(self, command: str, stream) -> str: """ Handles an osp command in a string. @@ -1195,10 +1195,13 @@ def handle_command(self, command: str, stream) -> str: if not command and tree.tag != "authenticate": raise OspdCommandError('Bogus command name') - if tree.tag == "get_vts": - return command.handle_xml(tree, stream) + response = command.handle_xml(tree) - return command.handle_xml(tree) + if isinstance(response, bytes): + stream.write(response) + else: + for data in response: + stream.write(data) def check(self): """ Asserts to False. Should be implemented by subclass. """ From a5cc4bf50a20654f4aa2ed13774e2e21b81a318f Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 16:05:33 +0100 Subject: [PATCH 09/11] Fix tests. Add a fake stream --- tests/command/test_commands.py | 6 +- tests/helper.py | 13 + tests/test_scan_and_result.py | 549 ++++++++++++++++++--------------- 3 files changed, 317 insertions(+), 251 deletions(-) diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index 7c0edc4f..209fee6f 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -24,7 +24,7 @@ from ospd.command.command import GetPerformance, StartScan, StopScan from ospd.errors import OspdCommandError, OspdError -from ..helper import DummyWrapper, assert_called +from ..helper import DummyWrapper, assert_called, FakeStream class GetPerformanceTestCase(TestCase): @@ -261,13 +261,15 @@ def test_stop_scan(self, mock_create_process, mock_os): mock_process.is_alive.return_value = True mock_process.pid = "foo" + fs = FakeStream() daemon = DummyWrapper([]) request = ( '' '' '' ) - response = et.fromstring(daemon.handle_command(request)) + daemon.handle_command(request, fs) + response = fs.get_response() assert_called(mock_create_process) assert_called(mock_process.start) diff --git a/tests/helper.py b/tests/helper.py index 0f6494e1..425a05e5 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -20,6 +20,8 @@ from unittest.mock import Mock +from xml.etree import ElementTree as et + from ospd.ospd import OSPDaemon @@ -36,6 +38,17 @@ def assert_called(mock: Mock): raise AssertionError(msg) +class FakeStream: + def __init__(self): + self.response = b'' + + def write(self, data): + self.response = self.response + data + + def get_response(self): + return et.fromstring(self.response) + + class DummyWrapper(OSPDaemon): def __init__(self, results, checkresult=True): super().__init__() diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 058b2de5..307047eb 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -31,7 +31,7 @@ from defusedxml.common import EntitiesForbidden -from .helper import DummyWrapper, assert_called +from .helper import DummyWrapper, assert_called, FakeStream class FakeStartProcess: @@ -77,9 +77,10 @@ def __init__(self, type_, **kwargs): class ScanTestCase(unittest.TestCase): def test_get_default_scanner_params(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -90,35 +91,44 @@ def test_get_default_scanner_params(self): def test_get_default_help(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertEqual(response.tag, 'help_response') def test_get_default_scanner_version(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('protocol')) def test_get_vts_no_vt(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('vts')) def test_get_vts_single_vt(self): daemon = DummyWrapper([]) + fs = FakeStream() daemon.add_vt('1.2.3.4', 'A vulnerability test') - response = secET.fromstring(daemon.handle_command('')) + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -136,12 +146,12 @@ def test_get_vts_filter_positive(self): vt_params="a", vt_modification_time='19000202', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', fs ) + response = fs.get_response() self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -164,12 +174,12 @@ def test_get_vts_filter_negative(self): vt_params="a", vt_modification_time='19000202', ) - - response = secET.fromstring( - daemon.handle_command( - '' - ) + fs = FakeStream() + daemon.handle_command( + '', fs, ) + response = fs.get_response() + self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -190,8 +200,10 @@ def test_get_vtss_multiple_vts(self): daemon.add_vt('1.2.3.5', 'Another vulnerability test') daemon.add_vt('123456789', 'Yet another vulnerability test') - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -204,8 +216,11 @@ def test_get_vts_multiple_vts_with_custom(self): '4.3.2.1', 'Another vulnerability test with custom info', custom='b' ) daemon.add_vt('123456789', 'Yet another vulnerability test', custom='b') + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring(daemon.handle_command('')) custom = response.findall('vts/vt/custom') self.assertEqual(3, len(custom)) @@ -215,10 +230,11 @@ def test_get_vts_vts_with_params(self): daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b" ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring( - daemon.handle_command('') - ) # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -245,10 +261,11 @@ def test_get_vts_vts_with_refs(self): custom="b", vt_refs="c", ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring( - daemon.handle_command('') - ) # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -276,10 +293,11 @@ def test_get_vts_vts_with_dependencies(self): custom="b", vt_dependencies="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + + response = fs.get_response() deps = response.findall('vts/vt/dependencies/dependency') self.assertEqual(2, len(deps)) @@ -293,10 +311,10 @@ def test_get_vts_vts_with_severities(self): custom="b", severities="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() severity = response.findall('vts/vt/severities/severity') self.assertEqual(1, len(severity)) @@ -311,10 +329,10 @@ def test_get_vts_vts_with_detection_qodt(self): detection="c", qod_t="d", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) @@ -329,10 +347,10 @@ def test_get_vts_vts_with_detection_qodv(self): detection="c", qod_v="d", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) @@ -346,10 +364,10 @@ def test_get_vts_vts_with_summary(self): custom="b", summary="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() summary = response.findall('vts/vt/summary') self.assertEqual(1, len(summary)) @@ -363,10 +381,10 @@ def test_get_vts_vts_with_impact(self): custom="b", impact="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() impact = response.findall('vts/vt/impact') self.assertEqual(1, len(impact)) @@ -380,10 +398,10 @@ def test_get_vts_vts_with_affected(self): custom="b", affected="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() affect = response.findall('vts/vt/affected') self.assertEqual(1, len(affect)) @@ -397,10 +415,10 @@ def test_get_vts_vts_with_insight(self): custom="b", insight="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() insight = response.findall('vts/vt/insight') self.assertEqual(1, len(insight)) @@ -416,10 +434,10 @@ def test_get_vts_vts_with_solution(self): solution_t="d", solution_m="e", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() solution = response.findall('vts/vt/solution') self.assertEqual(1, len(solution)) @@ -432,10 +450,10 @@ def test_get_vts_vts_with_ctime(self): vt_params="a", vt_creation_time='01-01-1900', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() creation_time = response.findall('vts/vt/creation_time') self.assertEqual( @@ -451,10 +469,10 @@ def test_get_vts_vts_with_mtime(self): vt_params="a", vt_modification_time='02-01-1900', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() modification_time = response.findall('vts/vt/modification_time') self.assertEqual( @@ -464,22 +482,26 @@ def test_get_vts_vts_with_mtime(self): def test_clean_forgotten_scans(self): daemon = DummyWrapper([]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') finished = False + while not finished: - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + scans = response.findall('scan') self.assertEqual(1, len(scans)) @@ -492,11 +514,12 @@ def test_clean_forgotten_scans(self): else: finished = True - response = secET.fromstring( + fs = FakeStream() daemon.handle_command( - '' % scan_id + '' % scan_id, fs ) - ) + response = fs.get_response() + self.assertEqual(len(list(daemon.scan_collection.ids_iterator())), 1) # Set an old end_time @@ -514,22 +537,24 @@ def test_clean_forgotten_scans(self): def test_scan_with_error(self): daemon = DummyWrapper([Result('error', value='something went wrong')]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') finished = False while not finished: - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + scans = response.findall('scan') self.assertEqual(1, len(scans)) @@ -542,54 +567,57 @@ def test_scan_with_error(self): else: finished = True - response = secET.fromstring( + fs = FakeStream() + daemon.handle_command( - '' % scan_id + '' % scan_id, fs ) - ) + response = fs.get_response() self.assertEqual( response.findtext('scan/results/result'), 'something went wrong' ) - - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') def test_get_scan_pop(self): daemon = DummyWrapper([Result('host-detail', value='Some Host Detail')]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - '' - ) + daemon.handle_command( + '' + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') time.sleep(1) - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() + self.assertEqual( response.findtext('scan/results/result'), 'Some Host Detail' ) - - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + self.assertEqual( response.findtext('scan/results/result'), 'Some Host Detail' ) - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() + self.assertEqual(response.findtext('scan/results/result'), None) def test_get_scan_pop_max_res(self): @@ -600,30 +628,33 @@ def test_get_scan_pop_max_res(self): Result('host-detail', value='Some Host Detail2'), ] ) - - response = secET.fromstring( - daemon.handle_command( - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') time.sleep(1) - response = secET.fromstring( - daemon.handle_command( - '' - % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' + % scan_id, + fs, ) + response = fs.get_response() + self.assertEqual(len(response.findall('scan/results/result')), 1) - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + self.assertEqual(len(response.findall('scan/results/result')), 2) def test_billon_laughs(self): @@ -645,49 +676,53 @@ def test_billon_laughs(self): ' ' ']>' ) - self.assertRaises(EntitiesForbidden, daemon.handle_command, lol) + fs = FakeStream() + self.assertRaises(EntitiesForbidden, daemon.handle_command, lol, fs) def test_scan_multi_target(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - 'localhosts' - '80,443' - '0' - '' - '192.168.0.0/24' - '22' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + 'localhosts' + '80,443' + '0' + '' + '192.168.0.0/24' + '22' + '', + fs, + ) + response = fs.get_response() + self.assertEqual(response.get('status'), '200') def test_multi_target_with_credentials(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhosts' - '80,443' - '192.168.0.0/2422' - '' - '' - 'scanuser' - 'mypass' - '' - 'smbuser' - 'mypass' - '' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhosts' + '80,443' + '192.168.0.0/2422' + '' + '' + 'scanuser' + 'mypass' + '' + 'smbuser' + 'mypass' + '' + '' + '', + fs, + ) + response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -706,41 +741,46 @@ def test_multi_target_with_credentials(self): def test_scan_get_target(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - 'localhosts' - '80,443' - '' - '192.168.0.0/24' - '22' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + 'localhosts' + '80,443' + '' + '192.168.0.0/24' + '22' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() + scan_res = response.find('scan') self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24') def test_scan_get_target_options(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - '192.168.0.1' - '220' - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + '192.168.0.1' + '220' + '' + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') time.sleep(1) target_options = daemon.get_scan_target_options(scan_id, '192.168.0.1') @@ -748,23 +788,25 @@ def test_scan_get_target_options(self): def test_scan_get_finished_hosts(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - '192.168.10.20-25' - '80,443' - '192.168.10.23-24' - '' - '' - '192.168.0.0/24' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + '192.168.10.20-25' + '80,443' + '192.168.10.23-24' + '' + '' + '192.168.0.0/24' + '22' + '' + '', + fs, + ) + response = fs.get_response() + scan_id = response.findtext('id') time.sleep(1) finished = daemon.get_scan_finished_hosts(scan_id) @@ -773,20 +815,21 @@ def test_scan_get_finished_hosts(self): def test_progress(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhost1' - '22' - '' - 'localhost2' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhost1' + '22' + '' + 'localhost2' + '22' + '' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') @@ -827,17 +870,18 @@ def test_resume_task(self, mock_create_process, _mock_os): mock_process.is_alive.return_value = True mock_process.pid = "main-scan-process" - response = ET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhost' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhost' + '22' + '' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') self.assertIsNotNone(scan_id) @@ -845,13 +889,14 @@ def test_resume_task(self, mock_create_process, _mock_os): assert_called(mock_create_process) assert_called(mock_process.start) - daemon.handle_command('' % scan_id) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) - response = ET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() result = response.findall('scan/results/result') self.assertEqual(len(result), 2) @@ -862,7 +907,9 @@ def test_resume_task(self, mock_create_process, _mock_os): '' ''.format(scan_id) ) - response = ET.fromstring(daemon.handle_command(cmd)) + fs = FakeStream() + daemon.handle_command(cmd, fs) + response = fs.get_response() # Check unfinished host self.assertEqual(response.findtext('id'), scan_id) @@ -880,11 +927,11 @@ def test_resume_task(self, mock_create_process, _mock_os): ) # Check if the result was removed. - response = ET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() result = response.findall('scan/results/result') # current the response still contains the results @@ -892,27 +939,31 @@ def test_resume_task(self, mock_create_process, _mock_os): def test_result_order(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'a' - '22' - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'a' + '22' + '' + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') daemon.add_scan_log(scan_id, host='a', name='a') daemon.add_scan_log(scan_id, host='c', name='c') daemon.add_scan_log(scan_id, host='b', name='b') hosts = ['a', 'c', 'b'] - response = secET.fromstring( - daemon.handle_command('') - ) + + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() + results = response.findall("scan/results/") for idx, res in enumerate(results): From ab9f88b860c0a9ea7a1e02a86eef2fe30788ccc8 Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Mon, 10 Feb 2020 17:11:39 +0100 Subject: [PATCH 10/11] Use iter to iterate over a DictProxy, since python3.5 does not support direct iteration. --- ospd/vtfilter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ospd/vtfilter.py b/ospd/vtfilter.py index 756bd058..9fcc243e 100644 --- a/ospd/vtfilter.py +++ b/ospd/vtfilter.py @@ -122,7 +122,9 @@ def get_filtered_vts_list( vt_oid_list = list(vts.keys()) for _element, _oper, _filter_val in filters: - vts_generator = (vt for vt in vts) + # Use iter because python3.5 has no support for + # iteration over DictProxy. + vts_generator = (vt for vt in iter(vts.keys())) for vt_oid in vts_generator: if vt_oid not in vt_oid_list: continue From 670ab0a102c1f580ce13fc8eb1404f20113e18fe Mon Sep 17 00:00:00 2001 From: Juan Jose Nicola Date: Tue, 11 Feb 2020 09:17:18 +0100 Subject: [PATCH 11/11] Improve generator --- ospd/command/command.py | 26 ++++++++------------------ ospd/ospd.py | 11 ++++++----- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/ospd/command/command.py b/ospd/command/command.py index b0016346..fe926990 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -308,6 +308,7 @@ def handle_xml(self, xml: Element) -> str: if vt_id and vt_id not in self._daemon.vts: text = "Failed to find vulnerability test '{0}'".format(vt_id) + raise OspdCommandError(text, 'get_vts', 404) filtered_vts = None if vt_filter: @@ -315,26 +316,15 @@ def handle_xml(self, xml: Element) -> str: self._daemon.vts, vt_filter ) - # Generator - vts_list = (vt for vt in self._daemon.get_vts_xml(vt_id, filtered_vts)) - # List of xml pieces with the generator to be iterated - response = [ - xml_helper.create_response('get_vts'), - xml_helper.create_element('vts'), - vts_list, - xml_helper.create_element('vts', end=True), - xml_helper.create_response('get_vts', end=True), - ] + yield xml_helper.create_response('get_vts') + yield xml_helper.create_element('vts') - for elem in response: - if isinstance(elem, GeneratorType): - for vts_chunk in elem: - yield xml_helper.add_element( - self._daemon.get_vt_xml(vts_chunk) - ) - else: - yield elem + for vt in self._daemon.get_vts_selection_list(vt_id, filtered_vts): + yield xml_helper.add_element(self._daemon.get_vt_xml(vt)) + + yield xml_helper.create_element('vts', end=True) + yield xml_helper.create_response('get_vts', end=True) class StopScan(BaseCommand): diff --git a/ospd/ospd.py b/ospd/ospd.py index 3ca40e07..1c1db5d2 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -1142,9 +1142,11 @@ def get_vt_xml(self, vt_id: str): return vt_xml - def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): - """ Python Generator for VTS. - Gets collection of vulnerability test information in XML format. + def get_vts_selection_list( + self, vt_id: str = None, filtered_vts: Dict = None + ) -> List: + """ + Get list of VT's OID. If vt_id is specified, the collection will contain only this vt, if found. If no vt_id is specified or filtered_vts is None (default), the @@ -1157,8 +1159,7 @@ def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): filtered_vts (list, optional): Filtered VTs collection. Return: - String of collection of vulnerability test information in - XML format. + List of selected VT's OID. """ vts_xml = [] if not self.vts: