Skip to content
This repository has been archived by the owner on Nov 29, 2021. It is now read-only.

Simplify start_scan() #275

Merged
merged 5 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Show progress as integer for get_scans. [#269](https://github.com/greenbone/ospd/pull/269)
- Make scan_id attribute mandatory for get_scans. [#270](https://github.com/greenbone/ospd/pull/270)
- Ignore subsequent SIGINT once inside exit_cleanup(). [#273](https://github.com/greenbone/ospd/pull/273)
- Simplify start_scan() [#275](https://github.com/greenbone/ospd/pull/275)

### Fixed
- Fix stop scan. Wait for the scan process to be stopped before delete it from the process table. [#204](https://github.com/greenbone/ospd/pull/204)
Expand Down
19 changes: 5 additions & 14 deletions ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,25 +505,21 @@ def handle_client_stream(self, stream: Stream) -> None:

stream.close()

def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None:
def process_finished_hosts(self, scan_id: str) -> None:
""" Process the finished hosts before launching the scans."""

finished_hosts = self.scan_collection.get_finished_hosts(scan_id)
if not finished_hosts:
return

exc_finished_hosts_list = target_str_to_list(finished_hosts)
self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list)

def start_scan(self, scan_id: str, target: Dict) -> None:
def start_scan(self, scan_id: str) -> None:
""" Starts the scan with scan_id. """
os.setsid()

if target is None or not target:
raise OspdCommandError('Erroneous target', 'start_scan')

logger.info("%s: Scan started.", scan_id)

self.process_finished_hosts(scan_id, target.get('finished_hosts'))
self.process_finished_hosts(scan_id)

try:
self.set_scan_status(scan_id, ScanStatus.RUNNING)
Expand Down Expand Up @@ -1188,13 +1184,8 @@ def run(self) -> None:
def start_pending_scans(self):
for scan_id in self.scan_collection.ids_iterator():
if self.get_scan_status(scan_id) == ScanStatus.PENDING:
scan_target = self.scan_collection.scans_table[scan_id].get(
'target'
)
scan_func = self.start_scan
scan_process = create_process(
func=scan_func, args=(scan_id, scan_target)
)
scan_process = create_process(func=scan_func, args=(scan_id,))
self.scan_processes[scan_id] = scan_process
scan_process.start()
self.set_scan_status(scan_id, ScanStatus.INIT)
Expand Down
2 changes: 1 addition & 1 deletion ospd/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class OspRequest:
@staticmethod
def process_vts_params(
scanner_vts: Element,
) -> Dict[str, Union[Dict, List]]:
) -> Dict[str, Union[Dict[str, str], List]]:
""" Receive an XML object with the Vulnerability Tests an their
parameters to be use in a scan and return a dictionary.

Expand Down
14 changes: 7 additions & 7 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections import OrderedDict
from enum import Enum
from typing import List, Any, Dict, Iterator, Optional, Iterable
from typing import List, Any, Dict, Iterator, Optional, Iterable, Union

from ospd.network import target_str_to_list

Expand Down Expand Up @@ -347,38 +347,38 @@ def get_host_count(self, scan_id: str) -> int:

return total_hosts

def get_ports(self, scan_id: str):
def get_ports(self, scan_id: str) -> str:
""" Get a scan's ports list.
"""
target = self.scans_table[scan_id].get('target')
ports = target.pop('ports')
self.scans_table[scan_id]['target'] = target
return ports

def get_exclude_hosts(self, scan_id: str):
def get_exclude_hosts(self, scan_id: str) -> str:
""" Get an exclude host list for a given target.
"""
return self.scans_table[scan_id]['target'].get('exclude_hosts')

def get_finished_hosts(self, scan_id: str):
def get_finished_hosts(self, scan_id: str) -> str:
""" Get the finished host list sent by the client for a given target.
"""
return self.scans_table[scan_id]['target'].get('finished_hosts')

def get_credentials(self, scan_id: str):
def get_credentials(self, scan_id: str) -> Dict[str, Dict[str, str]]:
""" Get a scan's credential list. It return dictionary with
the corresponding credential for a given target.
"""
return self.scans_table[scan_id]['target'].get('credentials')

def get_target_options(self, scan_id: str):
def get_target_options(self, scan_id: str) -> Dict[str, str]:
""" Get a scan's target option dictionary.
It return dictionary with the corresponding options for
a given target.
"""
return self.scans_table[scan_id]['target'].get('options')

def get_vts(self, scan_id: str) -> Dict:
def get_vts(self, scan_id: str) -> Dict[str, Union[Dict[str, str], List]]:
""" Get a scan's vts. """
scan_info = self.scans_table[scan_id]
vts = scan_info.pop('vts')
Expand Down