diff --git a/CHANGELOG.md b/CHANGELOG.md index 98109267..091527fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Modify __init__() method and use new syntax for super(). [#186](https://github.com/greenbone/ospd/pull/186) - Create data manager and spawn new process to keep the vts dictionary. [#191](https://github.com/greenbone/ospd/pull/191) - Update daemon start sequence. Run daemon.check before daemon.init now. [#197](https://github.com/greenbone/ospd/pull/197) +- Improve get_vts cmd response, sending the vts piece by piece.[#201](https://github.com/greenbone/ospd/pull/201) ## [2.0.1] (unreleased) diff --git a/ospd/command/__init__.py b/ospd/command/__init__.py index 26114314..bae86a6a 100644 --- a/ospd/command/__init__.py +++ b/ospd/command/__init__.py @@ -16,4 +16,5 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. +import ospd.command.command from .registry import get_commands diff --git a/ospd/command/command.py b/ospd/command/command.py index fe92e2cd..fe926990 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -19,6 +19,7 @@ import re import subprocess +from types import GeneratorType from typing import Optional, Dict, Any from xml.etree.ElementTree import Element, SubElement @@ -27,7 +28,11 @@ from ospd.misc import valid_uuid, create_process from ospd.network import target_str_to_list from ospd.protocol import OspRequest, OspResponse -from ospd.xml import simple_response_str, get_elements_from_dict +from ospd.xml import ( + simple_response_str, + get_elements_from_dict, + XmlStringHelper, +) from .initsubclass import InitSubclassMeta from .registry import register_command @@ -287,21 +292,23 @@ class GetVts(BaseCommand): def handle_xml(self, xml: Element) -> str: """ Handles command. + Writes the vt collection on the stream. The element accept two optional arguments. vt_id argument receives a single vt id. filter argument receives a filter selecting a sub set of vts. If both arguments are given, the vts which match with the filter are return. - @return: Response string for command. + @return: Response string for command on fail. """ + xml_helper = XmlStringHelper() vt_id = xml.get('vt_id') vt_filter = xml.get('filter') if vt_id and vt_id not in self._daemon.vts: text = "Failed to find vulnerability test '{0}'".format(vt_id) - return simple_response_str('get_vts', 404, text) + raise OspdCommandError(text, 'get_vts', 404) filtered_vts = None if vt_filter: @@ -309,13 +316,15 @@ def handle_xml(self, xml: Element) -> str: self._daemon.vts, vt_filter ) - responses = [] - - vts_xml = self._daemon.get_vts_xml(vt_id, filtered_vts) + # List of xml pieces with the generator to be iterated + yield xml_helper.create_response('get_vts') + yield xml_helper.create_element('vts') - responses.append(vts_xml) + for vt in self._daemon.get_vts_selection_list(vt_id, filtered_vts): + yield xml_helper.add_element(self._daemon.get_vt_xml(vt)) - return simple_response_str('get_vts', 200, 'OK', responses) + yield xml_helper.create_element('vts', end=True) + yield xml_helper.create_response('get_vts', end=True) class StopScan(BaseCommand): diff --git a/ospd/ospd.py b/ospd/ospd.py index 0364b068..1c1db5d2 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -495,8 +495,9 @@ def handle_client_stream(self, stream) -> None: logger.debug("Empty client stream") return + response = None try: - response = self.handle_command(data) + self.handle_command(data, stream) except OspdCommandError as exception: response = exception.as_xml() logger.debug('Command error: %s', exception.message) @@ -505,7 +506,8 @@ def handle_client_stream(self, stream) -> None: exception = OspdCommandError('Fatal error', 'error') response = exception.as_xml() - stream.write(response) + if response: + stream.write(response) stream.close() def parallel_scan(self, scan_id: str, target: str) -> None: @@ -1140,8 +1142,11 @@ def get_vt_xml(self, vt_id: str): return vt_xml - def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): - """ Gets collection of vulnerability test information in XML format. + def get_vts_selection_list( + self, vt_id: str = None, filtered_vts: Dict = None + ) -> List: + """ + Get list of VT's OID. If vt_id is specified, the collection will contain only this vt, if found. If no vt_id is specified or filtered_vts is None (default), the @@ -1151,36 +1156,32 @@ def get_vts_xml(self, vt_id: str = None, filtered_vts: Dict = None): Arguments: vt_id (vt_id, optional): ID of the vt to get. - filtered_vts (dict, optional): Filtered VTs collection. + filtered_vts (list, optional): Filtered VTs collection. Return: - String of collection of vulnerability test information in - XML format. + List of selected VT's OID. """ - - vts_xml = Element('vts') - + vts_xml = [] if not self.vts: return vts_xml + # No match for the filter if filtered_vts is not None and len(filtered_vts) == 0: return vts_xml if filtered_vts: - for vt_id in filtered_vts: - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = filtered_vts elif vt_id: - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = [vt_id] else: - # Because DictProxy for python3.5 doesn't support + # TODO: Because DictProxy for python3.5 doesn't support # iterkeys(), itervalues(), or iteritems() either, the iteration # must be done as follow. - for vt_id in iter(self.vts.keys()): - vts_xml.append(self.get_vt_xml(vt_id)) + vts_list = iter(self.vts.keys()) - return vts_xml + return vts_list - def handle_command(self, command: str) -> str: + def handle_command(self, command: str, stream) -> str: """ Handles an osp command in a string. @return: OSP Response to command. @@ -1195,7 +1196,13 @@ def handle_command(self, command: str) -> str: if not command and tree.tag != "authenticate": raise OspdCommandError('Bogus command name') - return command.handle_xml(tree) + response = command.handle_xml(tree) + + if isinstance(response, bytes): + stream.write(response) + else: + for data in response: + stream.write(data) def check(self): """ Asserts to False. Should be implemented by subclass. """ diff --git a/ospd/vtfilter.py b/ospd/vtfilter.py index bfe4f8a2..9fcc243e 100644 --- a/ospd/vtfilter.py +++ b/ospd/vtfilter.py @@ -98,7 +98,9 @@ def format_filter_value(self, element: str, value: Dict): format_func = self.allowed_filter.get(element) return format_func(value) - def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: + def get_filtered_vts_list( + self, vts: Dict, vt_filter: str + ) -> Optional[Dict]: """ Gets a collection of vulnerability test from the vts dictionary, which match the filter. @@ -107,7 +109,8 @@ def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: vts (dictionary): The complete vts collection. Returns: - Dictionary with filtered vulnerability tests. + List with filtered vulnerability tests. The list can be empty. + None in case of filter parse failure. """ if not vt_filter: raise OspdCommandError('vt_filter: A valid filter is required.') @@ -116,17 +119,23 @@ def get_filtered_vts_list(self, vts: Dict, vt_filter: str) -> Optional[Dict]: if not filters: return None - _vts_aux = vts.copy() + vt_oid_list = list(vts.keys()) + for _element, _oper, _filter_val in filters: - for vt_id in _vts_aux.copy(): - if not _vts_aux[vt_id].get(_element): - _vts_aux.pop(vt_id) + # Use iter because python3.5 has no support for + # iteration over DictProxy. + vts_generator = (vt for vt in iter(vts.keys())) + for vt_oid in vts_generator: + if vt_oid not in vt_oid_list: + continue + if not vts[vt_oid].get(_element): + vt_oid_list.remove(vt_oid) continue - _elem_val = _vts_aux[vt_id].get(_element) + _elem_val = vts[vt_oid].get(_element) _val = self.format_filter_value(_element, _elem_val) if self.filter_operator[_oper](_val, _filter_val): continue else: - _vts_aux.pop(vt_id) + vt_oid_list.remove(vt_oid) - return _vts_aux + return vt_oid_list diff --git a/ospd/xml.py b/ospd/xml.py index 62cedbed..fc50db69 100644 --- a/ospd/xml.py +++ b/ospd/xml.py @@ -133,3 +133,82 @@ def elements_as_text( text = ''.join([text, ele_txt]) return text + + +class XmlStringHelper: + """ Class with methods to help the creation of a xml object in + string format. + """ + + def create_element(self, elem_name: str, end: bool = False) -> bytes: + """ Get a name and create the open element of an entity. + + Arguments: + elem_name (str): The name of the tag element. + end (bool): Create a initial tag if False, otherwise the end tag. + + Return: + Encoded string representing a part of an xml element. + """ + if end: + ret = "" % elem_name + else: + ret = "<%s>" % elem_name + + return ret.encode() + + def create_response(self, command: str, end: bool = False) -> bytes: + """ Create or end an xml response. + + Arguments: + command (str): The name of the command for the response element. + end (bool): Create a initial tag if False, otherwise the end tag. + + Return: + Encoded string representing a part of an xml element. + """ + if not command: + return + + if end: + return ('' % command).encode() + + return ( + '<%s_response status="200" status_text="OK">' % command + ).encode() + + def add_element( + self, + content: Union[Element, str, list], + xml_str: bytes = None, + end: bool = False, + ) -> bytes: + """Create the initial or ending tag for a subelement, or add + one or many xml elements + + Arguments: + content (Element, str, list): Content to add. + xml_str (bytes): Initial string where content to be added to. + end (bool): Create a initial tag if False, otherwise the end tag. + It will be added to the xml_str. + + Return: + Encoded string representing a part of an xml element. + """ + + if not xml_str: + xml_str = b'' + + if content: + if isinstance(content, list): + for elem in content: + xml_str = xml_str + tostring(elem) + elif isinstance(content, Element): + xml_str = xml_str + tostring(content) + else: + if end: + xml_str = xml_str + self.create_element(content, False) + else: + xml_str = xml_str + self.create_element(content) + + return xml_str diff --git a/tests/command/test_commands.py b/tests/command/test_commands.py index 7c0edc4f..209fee6f 100644 --- a/tests/command/test_commands.py +++ b/tests/command/test_commands.py @@ -24,7 +24,7 @@ from ospd.command.command import GetPerformance, StartScan, StopScan from ospd.errors import OspdCommandError, OspdError -from ..helper import DummyWrapper, assert_called +from ..helper import DummyWrapper, assert_called, FakeStream class GetPerformanceTestCase(TestCase): @@ -261,13 +261,15 @@ def test_stop_scan(self, mock_create_process, mock_os): mock_process.is_alive.return_value = True mock_process.pid = "foo" + fs = FakeStream() daemon = DummyWrapper([]) request = ( '' '' '' ) - response = et.fromstring(daemon.handle_command(request)) + daemon.handle_command(request, fs) + response = fs.get_response() assert_called(mock_create_process) assert_called(mock_process.start) diff --git a/tests/helper.py b/tests/helper.py index 0f6494e1..425a05e5 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -20,6 +20,8 @@ from unittest.mock import Mock +from xml.etree import ElementTree as et + from ospd.ospd import OSPDaemon @@ -36,6 +38,17 @@ def assert_called(mock: Mock): raise AssertionError(msg) +class FakeStream: + def __init__(self): + self.response = b'' + + def write(self, data): + self.response = self.response + data + + def get_response(self): + return et.fromstring(self.response) + + class DummyWrapper(OSPDaemon): def __init__(self, results, checkresult=True): super().__init__() diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index 058b2de5..307047eb 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -31,7 +31,7 @@ from defusedxml.common import EntitiesForbidden -from .helper import DummyWrapper, assert_called +from .helper import DummyWrapper, assert_called, FakeStream class FakeStartProcess: @@ -77,9 +77,10 @@ def __init__(self, type_, **kwargs): class ScanTestCase(unittest.TestCase): def test_get_default_scanner_params(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -90,35 +91,44 @@ def test_get_default_scanner_params(self): def test_get_default_help(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertEqual(response.tag, 'help_response') def test_get_default_scanner_version(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('protocol')) def test_get_vts_no_vt(self): daemon = DummyWrapper([]) - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') self.assertIsNotNone(response.find('vts')) def test_get_vts_single_vt(self): daemon = DummyWrapper([]) + fs = FakeStream() daemon.add_vt('1.2.3.4', 'A vulnerability test') - response = secET.fromstring(daemon.handle_command('')) + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -136,12 +146,12 @@ def test_get_vts_filter_positive(self): vt_params="a", vt_modification_time='19000202', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', fs ) + response = fs.get_response() self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -164,12 +174,12 @@ def test_get_vts_filter_negative(self): vt_params="a", vt_modification_time='19000202', ) - - response = secET.fromstring( - daemon.handle_command( - '' - ) + fs = FakeStream() + daemon.handle_command( + '', fs, ) + response = fs.get_response() + self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -190,8 +200,10 @@ def test_get_vtss_multiple_vts(self): daemon.add_vt('1.2.3.5', 'Another vulnerability test') daemon.add_vt('123456789', 'Yet another vulnerability test') - response = secET.fromstring(daemon.handle_command('')) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') vts = response.find('vts') @@ -204,8 +216,11 @@ def test_get_vts_multiple_vts_with_custom(self): '4.3.2.1', 'Another vulnerability test with custom info', custom='b' ) daemon.add_vt('123456789', 'Yet another vulnerability test', custom='b') + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring(daemon.handle_command('')) custom = response.findall('vts/vt/custom') self.assertEqual(3, len(custom)) @@ -215,10 +230,11 @@ def test_get_vts_vts_with_params(self): daemon.add_vt( '1.2.3.4', 'A vulnerability test', vt_params="a", custom="b" ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring( - daemon.handle_command('') - ) # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -245,10 +261,11 @@ def test_get_vts_vts_with_refs(self): custom="b", vt_refs="c", ) + fs = FakeStream() + + daemon.handle_command('', fs) + response = fs.get_response() - response = secET.fromstring( - daemon.handle_command('') - ) # The status of the response must be success (i.e. 200) self.assertEqual(response.get('status'), '200') @@ -276,10 +293,11 @@ def test_get_vts_vts_with_dependencies(self): custom="b", vt_dependencies="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + + response = fs.get_response() deps = response.findall('vts/vt/dependencies/dependency') self.assertEqual(2, len(deps)) @@ -293,10 +311,10 @@ def test_get_vts_vts_with_severities(self): custom="b", severities="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() severity = response.findall('vts/vt/severities/severity') self.assertEqual(1, len(severity)) @@ -311,10 +329,10 @@ def test_get_vts_vts_with_detection_qodt(self): detection="c", qod_t="d", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) @@ -329,10 +347,10 @@ def test_get_vts_vts_with_detection_qodv(self): detection="c", qod_v="d", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() detection = response.findall('vts/vt/detection') self.assertEqual(1, len(detection)) @@ -346,10 +364,10 @@ def test_get_vts_vts_with_summary(self): custom="b", summary="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() summary = response.findall('vts/vt/summary') self.assertEqual(1, len(summary)) @@ -363,10 +381,10 @@ def test_get_vts_vts_with_impact(self): custom="b", impact="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() impact = response.findall('vts/vt/impact') self.assertEqual(1, len(impact)) @@ -380,10 +398,10 @@ def test_get_vts_vts_with_affected(self): custom="b", affected="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() affect = response.findall('vts/vt/affected') self.assertEqual(1, len(affect)) @@ -397,10 +415,10 @@ def test_get_vts_vts_with_insight(self): custom="b", insight="c", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() insight = response.findall('vts/vt/insight') self.assertEqual(1, len(insight)) @@ -416,10 +434,10 @@ def test_get_vts_vts_with_solution(self): solution_t="d", solution_m="e", ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() solution = response.findall('vts/vt/solution') self.assertEqual(1, len(solution)) @@ -432,10 +450,10 @@ def test_get_vts_vts_with_ctime(self): vt_params="a", vt_creation_time='01-01-1900', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() creation_time = response.findall('vts/vt/creation_time') self.assertEqual( @@ -451,10 +469,10 @@ def test_get_vts_vts_with_mtime(self): vt_params="a", vt_modification_time='02-01-1900', ) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command('') - ) + daemon.handle_command('', fs) + response = fs.get_response() modification_time = response.findall('vts/vt/modification_time') self.assertEqual( @@ -464,22 +482,26 @@ def test_get_vts_vts_with_mtime(self): def test_clean_forgotten_scans(self): daemon = DummyWrapper([]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') finished = False + while not finished: - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + scans = response.findall('scan') self.assertEqual(1, len(scans)) @@ -492,11 +514,12 @@ def test_clean_forgotten_scans(self): else: finished = True - response = secET.fromstring( + fs = FakeStream() daemon.handle_command( - '' % scan_id + '' % scan_id, fs ) - ) + response = fs.get_response() + self.assertEqual(len(list(daemon.scan_collection.ids_iterator())), 1) # Set an old end_time @@ -514,22 +537,24 @@ def test_clean_forgotten_scans(self): def test_scan_with_error(self): daemon = DummyWrapper([Result('error', value='something went wrong')]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - ) + daemon.handle_command( + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') finished = False while not finished: - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + scans = response.findall('scan') self.assertEqual(1, len(scans)) @@ -542,54 +567,57 @@ def test_scan_with_error(self): else: finished = True - response = secET.fromstring( + fs = FakeStream() + daemon.handle_command( - '' % scan_id + '' % scan_id, fs ) - ) + response = fs.get_response() self.assertEqual( response.findtext('scan/results/result'), 'something went wrong' ) - - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() self.assertEqual(response.get('status'), '200') def test_get_scan_pop(self): daemon = DummyWrapper([Result('host-detail', value='Some Host Detail')]) + fs = FakeStream() - response = secET.fromstring( - daemon.handle_command( - '' - '' - ) + daemon.handle_command( + '' + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') time.sleep(1) - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() + self.assertEqual( response.findtext('scan/results/result'), 'Some Host Detail' ) - - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + self.assertEqual( response.findtext('scan/results/result'), 'Some Host Detail' ) - response = secET.fromstring( - daemon.handle_command('') - ) + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() + self.assertEqual(response.findtext('scan/results/result'), None) def test_get_scan_pop_max_res(self): @@ -600,30 +628,33 @@ def test_get_scan_pop_max_res(self): Result('host-detail', value='Some Host Detail2'), ] ) - - response = secET.fromstring( - daemon.handle_command( - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '', + fs, ) + response = fs.get_response() scan_id = response.findtext('id') time.sleep(1) - response = secET.fromstring( - daemon.handle_command( - '' - % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' + % scan_id, + fs, ) + response = fs.get_response() + self.assertEqual(len(response.findall('scan/results/result')), 1) - response = secET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() + self.assertEqual(len(response.findall('scan/results/result')), 2) def test_billon_laughs(self): @@ -645,49 +676,53 @@ def test_billon_laughs(self): ' ' ']>' ) - self.assertRaises(EntitiesForbidden, daemon.handle_command, lol) + fs = FakeStream() + self.assertRaises(EntitiesForbidden, daemon.handle_command, lol, fs) def test_scan_multi_target(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - 'localhosts' - '80,443' - '0' - '' - '192.168.0.0/24' - '22' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + 'localhosts' + '80,443' + '0' + '' + '192.168.0.0/24' + '22' + '', + fs, + ) + response = fs.get_response() + self.assertEqual(response.get('status'), '200') def test_multi_target_with_credentials(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhosts' - '80,443' - '192.168.0.0/2422' - '' - '' - 'scanuser' - 'mypass' - '' - 'smbuser' - 'mypass' - '' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhosts' + '80,443' + '192.168.0.0/2422' + '' + '' + 'scanuser' + 'mypass' + '' + 'smbuser' + 'mypass' + '' + '' + '', + fs, + ) + response = fs.get_response() self.assertEqual(response.get('status'), '200') @@ -706,41 +741,46 @@ def test_multi_target_with_credentials(self): def test_scan_get_target(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - 'localhosts' - '80,443' - '' - '192.168.0.0/24' - '22' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + 'localhosts' + '80,443' + '' + '192.168.0.0/24' + '22' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') - response = secET.fromstring( - daemon.handle_command('' % scan_id) - ) + + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) + response = fs.get_response() + scan_res = response.find('scan') self.assertEqual(scan_res.get('target'), 'localhosts,192.168.0.0/24') def test_scan_get_target_options(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - '192.168.0.1' - '220' - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + '192.168.0.1' + '220' + '' + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') time.sleep(1) target_options = daemon.get_scan_target_options(scan_id, '192.168.0.1') @@ -748,23 +788,25 @@ def test_scan_get_target_options(self): def test_scan_get_finished_hosts(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - '' - '192.168.10.20-25' - '80,443' - '192.168.10.23-24' - '' - '' - '192.168.0.0/24' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + '' + '192.168.10.20-25' + '80,443' + '192.168.10.23-24' + '' + '' + '192.168.0.0/24' + '22' + '' + '', + fs, + ) + response = fs.get_response() + scan_id = response.findtext('id') time.sleep(1) finished = daemon.get_scan_finished_hosts(scan_id) @@ -773,20 +815,21 @@ def test_scan_get_finished_hosts(self): def test_progress(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhost1' - '22' - '' - 'localhost2' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhost1' + '22' + '' + 'localhost2' + '22' + '' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') @@ -827,17 +870,18 @@ def test_resume_task(self, mock_create_process, _mock_os): mock_process.is_alive.return_value = True mock_process.pid = "main-scan-process" - response = ET.fromstring( - daemon.handle_command( - '' - '' - '' - 'localhost' - '22' - '' - '' - ) - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhost' + '22' + '' + '', + fs, + ) + response = fs.get_response() scan_id = response.findtext('id') self.assertIsNotNone(scan_id) @@ -845,13 +889,14 @@ def test_resume_task(self, mock_create_process, _mock_os): assert_called(mock_create_process) assert_called(mock_process.start) - daemon.handle_command('' % scan_id) + fs = FakeStream() + daemon.handle_command('' % scan_id, fs) - response = ET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() result = response.findall('scan/results/result') self.assertEqual(len(result), 2) @@ -862,7 +907,9 @@ def test_resume_task(self, mock_create_process, _mock_os): '' ''.format(scan_id) ) - response = ET.fromstring(daemon.handle_command(cmd)) + fs = FakeStream() + daemon.handle_command(cmd, fs) + response = fs.get_response() # Check unfinished host self.assertEqual(response.findtext('id'), scan_id) @@ -880,11 +927,11 @@ def test_resume_task(self, mock_create_process, _mock_os): ) # Check if the result was removed. - response = ET.fromstring( - daemon.handle_command( - '' % scan_id - ) + fs = FakeStream() + daemon.handle_command( + '' % scan_id, fs ) + response = fs.get_response() result = response.findall('scan/results/result') # current the response still contains the results @@ -892,27 +939,31 @@ def test_resume_task(self, mock_create_process, _mock_os): def test_result_order(self): daemon = DummyWrapper([]) - response = secET.fromstring( - daemon.handle_command( - '' - '' - '' - 'a' - '22' - '' - '' - ) + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'a' + '22' + '' + '', + fs, ) + response = fs.get_response() + scan_id = response.findtext('id') daemon.add_scan_log(scan_id, host='a', name='a') daemon.add_scan_log(scan_id, host='c', name='c') daemon.add_scan_log(scan_id, host='b', name='b') hosts = ['a', 'c', 'b'] - response = secET.fromstring( - daemon.handle_command('') - ) + + fs = FakeStream() + daemon.handle_command('', fs) + response = fs.get_response() + results = response.findall("scan/results/") for idx, res in enumerate(results):