diff --git a/CHANGELOG.md b/CHANGELOG.md index ae29f898..1f0a9738 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Show progress as integer for get_scans. [#269](https://github.com/greenbone/ospd/pull/269) - Make scan_id attribute mandatory for get_scans. [#270](https://github.com/greenbone/ospd/pull/270) - Ignore subsequent SIGINT once inside exit_cleanup(). [#273](https://github.com/greenbone/ospd/pull/273) +- Simplify start_scan() [#275](https://github.com/greenbone/ospd/pull/275) ### Fixed - Fix stop scan. Wait for the scan process to be stopped before delete it from the process table. [#204](https://github.com/greenbone/ospd/pull/204) diff --git a/ospd/ospd.py b/ospd/ospd.py index 6549ca46..88d623cb 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -505,25 +505,21 @@ def handle_client_stream(self, stream: Stream) -> None: stream.close() - def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None: + def process_finished_hosts(self, scan_id: str) -> None: """ Process the finished hosts before launching the scans.""" + finished_hosts = self.scan_collection.get_finished_hosts(scan_id) if not finished_hosts: return exc_finished_hosts_list = target_str_to_list(finished_hosts) self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list) - def start_scan(self, scan_id: str, target: Dict) -> None: + def start_scan(self, scan_id: str) -> None: """ Starts the scan with scan_id. """ os.setsid() - if target is None or not target: - raise OspdCommandError('Erroneous target', 'start_scan') - - logger.info("%s: Scan started.", scan_id) - - self.process_finished_hosts(scan_id, target.get('finished_hosts')) + self.process_finished_hosts(scan_id) try: self.set_scan_status(scan_id, ScanStatus.RUNNING) @@ -1188,13 +1184,8 @@ def run(self) -> None: def start_pending_scans(self): for scan_id in self.scan_collection.ids_iterator(): if self.get_scan_status(scan_id) == ScanStatus.PENDING: - scan_target = self.scan_collection.scans_table[scan_id].get( - 'target' - ) scan_func = self.start_scan - scan_process = create_process( - func=scan_func, args=(scan_id, scan_target) - ) + scan_process = create_process(func=scan_func, args=(scan_id,)) self.scan_processes[scan_id] = scan_process scan_process.start() self.set_scan_status(scan_id, ScanStatus.INIT) diff --git a/ospd/protocol.py b/ospd/protocol.py index 6f7d9792..1c04e0fe 100644 --- a/ospd/protocol.py +++ b/ospd/protocol.py @@ -47,7 +47,7 @@ class OspRequest: @staticmethod def process_vts_params( scanner_vts: Element, - ) -> Dict[str, Union[Dict, List]]: + ) -> Dict[str, Union[Dict[str, str], List]]: """ Receive an XML object with the Vulnerability Tests an their parameters to be use in a scan and return a dictionary. diff --git a/ospd/scan.py b/ospd/scan.py index 3df8bb4e..3d790651 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -22,7 +22,7 @@ from collections import OrderedDict from enum import Enum -from typing import List, Any, Dict, Iterator, Optional, Iterable +from typing import List, Any, Dict, Iterator, Optional, Iterable, Union from ospd.network import target_str_to_list @@ -347,7 +347,7 @@ def get_host_count(self, scan_id: str) -> int: return total_hosts - def get_ports(self, scan_id: str): + def get_ports(self, scan_id: str) -> str: """ Get a scan's ports list. """ target = self.scans_table[scan_id].get('target') @@ -355,30 +355,30 @@ def get_ports(self, scan_id: str): self.scans_table[scan_id]['target'] = target return ports - def get_exclude_hosts(self, scan_id: str): + def get_exclude_hosts(self, scan_id: str) -> str: """ Get an exclude host list for a given target. """ return self.scans_table[scan_id]['target'].get('exclude_hosts') - def get_finished_hosts(self, scan_id: str): + def get_finished_hosts(self, scan_id: str) -> str: """ Get the finished host list sent by the client for a given target. """ return self.scans_table[scan_id]['target'].get('finished_hosts') - def get_credentials(self, scan_id: str): + def get_credentials(self, scan_id: str) -> Dict[str, Dict[str, str]]: """ Get a scan's credential list. It return dictionary with the corresponding credential for a given target. """ return self.scans_table[scan_id]['target'].get('credentials') - def get_target_options(self, scan_id: str): + def get_target_options(self, scan_id: str) -> Dict[str, str]: """ Get a scan's target option dictionary. It return dictionary with the corresponding options for a given target. """ return self.scans_table[scan_id]['target'].get('options') - def get_vts(self, scan_id: str) -> Dict: + def get_vts(self, scan_id: str) -> Dict[str, Union[Dict[str, str], List]]: """ Get a scan's vts. """ scan_info = self.scans_table[scan_id] vts = scan_info.pop('vts')